diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc index 4904d2fe9ebe9..fa76842d8a1bb 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc @@ -62,6 +62,10 @@ T GsReflect(T x, T x_min, T x_max) { T dx = {}; T fx = static_cast(x); T range = x_max - x_min; + if (range <= static_cast(0)) { + return x_min; + } + if (fx < x_min) { dx = x_min - fx; int n = static_cast(dx / range); @@ -123,6 +127,8 @@ T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, in } else { // (padding_mode_ == Reflection) c = static_cast(GsReflect(static_cast(c), border[0], border[2])); r = static_cast(GsReflect(static_cast(r), border[1], border[3])); + c = std::clamp(c, 0, W - 1); + r = std::clamp(r, 0, H - 1); pixel = image[r * W + c]; } return pixel; @@ -144,6 +150,9 @@ T GridSample::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, w = static_cast(GsReflect(static_cast(w), border[0], border[3])); h = static_cast(GsReflect(static_cast(h), border[1], border[4])); d = static_cast(GsReflect(static_cast(d), border[2], border[5])); + w = std::clamp(w, 0, W - 1); + h = std::clamp(h, 0, H - 1); + d = std::clamp(d, 0, D - 1); pixel = image[d * H * W + h * W + w]; } return pixel; diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu index 0e7d947741924..82008f2e9c562 100755 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu @@ -51,7 +51,7 @@ __device__ T GsReflect(T x, float x_min, float x_max) { template __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { + int64_t padding_mode, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { @@ -69,8 +69,23 @@ __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_ y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); pixel = input_data[PixelOffset(x, y)]; } else { // Reflection - x = (int64_t)GsReflect(x, border[0], border[2]); - y = (int64_t)GsReflect(y, border[1], border[3]); + // Handle degenerate size-1 dimensions explicitly to avoid division by zero + // in GsReflect when x_min == x_max (range == 0), and clamp reflected + // coordinates into valid [0, H/W) ranges. + if (W == 1) { + x = 0; + } else { + x = (int64_t)GsReflect(x, border[0], border[2]); + x = max((int64_t)0, min((int64_t)W - 1, x)); + } + + if (H == 1) { + y = 0; + } else { + y = (int64_t)GsReflect(y, border[1], border[3]); + y = max((int64_t)0, min((int64_t)H - 1, y)); + } + pixel = input_data[PixelOffset(x, y)]; } return pixel; @@ -100,8 +115,8 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) { template __global__ void _GridSampleKernel( - const T* input_data, - const T* grid_data, + const T* __restrict__ input_data, + const T* __restrict__ grid_data, const int64_t mode, const int64_t padding_mode, const int64_t align_corners, @@ -111,7 +126,7 @@ __global__ void _GridSampleKernel( const int64_t W_in, const int64_t H_out, const int64_t W_out, - T* output_data) { + T* __restrict__ output_data) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); // extract batch index, channel index, y index, x index for current thread int BIdx, yIdx, xIdx, cIdx; @@ -200,10 +215,10 @@ __global__ void _GridSampleKernel( w_lb = w_b * w_l; w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, C, H_in, W_in, border); T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; output_data[outIdx] = interpoV; return; @@ -212,7 +227,7 @@ __global__ void _GridSampleKernel( int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; output_data[outIdx] = - PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -222,7 +237,7 @@ __global__ void _GridSampleKernel( for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { p[h][w] = - PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, C, H_in, W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -263,7 +278,7 @@ SPECIALIZED_IMPL(float, true) // NHWC template __device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) { + int64_t padding_mode, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) { T pixel = 0.0f; auto PixelOffset3D = [bIdx, cIdx, C, D, H, W](int64_t z, int64_t y, int64_t x) -> int64_t { @@ -283,9 +298,29 @@ __device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int6 pixel = input_data[PixelOffset3D(z, y, x)]; } else { // Reflection - z = (int64_t)GsReflect(z, border[0], border[3]); - y = (int64_t)GsReflect(y, border[1], border[4]); - x = (int64_t)GsReflect(x, border[2], border[5]); + // Handle degenerate size-1 dimensions explicitly to avoid division by zero + // in GsReflect when x_min == x_max (range == 0), and clamp reflected + // coordinates into valid [0, D/H/W) ranges. + if (D == 1) { + z = 0; + } else { + z = (int64_t)GsReflect(z, border[0], border[3]); + z = max((int64_t)0, min((int64_t)D - 1, z)); + } + + if (H == 1) { + y = 0; + } else { + y = (int64_t)GsReflect(y, border[1], border[4]); + y = max((int64_t)0, min((int64_t)H - 1, y)); + } + + if (W == 1) { + x = 0; + } else { + x = (int64_t)GsReflect(x, border[2], border[5]); + x = max((int64_t)0, min((int64_t)W - 1, x)); + } pixel = input_data[PixelOffset3D(z, y, x)]; } @@ -433,15 +468,15 @@ __global__ void _GridSampleKernel3D( w_lb_back = w_b * w_l * w_back; w_rb_back = w_b * w_r * w_back; - T lt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border); - T rt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border); - T lb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border); - T rb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border); + T lt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, C, D_in, H_in, W_in, border); + T rt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, C, D_in, H_in, W_in, border); + T lb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, C, D_in, H_in, W_in, border); + T rb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, C, D_in, H_in, W_in, border); - T lt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border); - T rt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border); - T lb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border); - T rb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border); + T lt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, C, D_in, H_in, W_in, border); + T rt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, C, D_in, H_in, W_in, border); + T lb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, C, D_in, H_in, W_in, border); + T rb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, C, D_in, H_in, W_in, border); T interpoV = w_lt_front * lt_front_v + w_rt_front * rt_front_v + w_lb_front * lb_front_v + w_rb_front * rb_front_v + w_lt_back * lt_back_v + w_rt_back * rt_back_v + w_lb_back * lb_back_v + w_rb_back * rb_back_v; @@ -455,7 +490,7 @@ __global__ void _GridSampleKernel3D( int z_n = grid_z_volSpace; output_data[outIdx] = - PixelAtGrid3D(input_data, BIdx, cIdx, z_n, y_n, x_n, padding_mode, N, C, D_in, H_in, W_in, border); + PixelAtGrid3D(input_data, BIdx, cIdx, z_n, y_n, x_n, padding_mode, C, D_in, H_in, W_in, border); return; } if (mode == 2) { // cubic diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc index 2423d7f120b20..9db3c079f95ea 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc @@ -71,3 +71,60 @@ TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bound test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders()); } + +TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_reflection_degenerate_spatial_dims_align_corners) { + // For a 1x1 input with align_corners=1, bilinear sampling always uses neighbors + // at indices {0, 1}. Reflection padding must map the out-of-bounds +1 neighbor + // back to index 0 instead of triggering invalid reflection math on CUDA. + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{1, 1, 1, 1}; + std::initializer_list X_data{TypeParam(3.25f)}; + std::initializer_list Grid_shape{1, 2, 2, 2}; + std::initializer_list Grid_data{ + TypeParam(-1.0f), TypeParam(-1.0f), + TypeParam(0.0f), TypeParam(0.0f), + TypeParam(0.5f), TypeParam(-0.5f), + TypeParam(1.0f), TypeParam(1.0f)}; + std::initializer_list Y_shape{1, 1, 2, 2}; + std::initializer_list Y_data{ + TypeParam(3.25f), TypeParam(3.25f), + TypeParam(3.25f), TypeParam(3.25f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + RunTests(test, GetExecutionProviders()); +} + +TYPED_TEST(GridSampleCustomTest, test_grid_sample_22_5D_linear_reflection_degenerate_spatial_dims_align_corners) { + // Same coverage for 3D GridSample: the 1x1x1 input forces reflected neighbors at + // index +1 for trilinear interpolation, which used to be unsafe on CUDA. + OpTester test("GridSample", 22); + std::string mode = "linear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{1, 1, 1, 1, 1}; + std::initializer_list X_data{TypeParam(-2.5f)}; + std::initializer_list Grid_shape{1, 2, 2, 1, 3}; + std::initializer_list Grid_data{ + TypeParam(-1.0f), TypeParam(-1.0f), TypeParam(-1.0f), + TypeParam(0.0f), TypeParam(0.0f), TypeParam(0.0f), + TypeParam(0.5f), TypeParam(-0.5f), TypeParam(1.0f), + TypeParam(1.0f), TypeParam(1.0f), TypeParam(1.0f)}; + std::initializer_list Y_shape{1, 1, 2, 2, 1}; + std::initializer_list Y_data{ + TypeParam(-2.5f), TypeParam(-2.5f), + TypeParam(-2.5f), TypeParam(-2.5f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + RunTests(test, GetExecutionProviders()); +}