Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/grid_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ T GsReflect(T x, T x_min, T x_max) {
T dx = {};
T fx = static_cast<T>(x);
T range = x_max - x_min;
if (range <= static_cast<T>(0)) {
return x_min;
}

if (fx < x_min) {
dx = x_min - fx;
int n = static_cast<int>(dx / range);
Expand Down Expand Up @@ -123,6 +127,8 @@ T GridSample<T>::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, in
} else { // (padding_mode_ == Reflection)
c = static_cast<int64_t>(GsReflect(static_cast<T>(c), border[0], border[2]));
r = static_cast<int64_t>(GsReflect(static_cast<T>(r), border[1], border[3]));
c = std::clamp<int64_t>(c, 0, W - 1);
r = std::clamp<int64_t>(r, 0, H - 1);
pixel = image[r * W + c];
}
return pixel;
Expand All @@ -144,6 +150,9 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w,
w = static_cast<int64_t>(GsReflect(static_cast<T>(w), border[0], border[3]));
h = static_cast<int64_t>(GsReflect(static_cast<T>(h), border[1], border[4]));
d = static_cast<int64_t>(GsReflect(static_cast<T>(d), border[2], border[5]));
w = std::clamp<int64_t>(w, 0, W - 1);
h = std::clamp<int64_t>(h, 0, H - 1);
d = std::clamp<int64_t>(d, 0, D - 1);
pixel = image[d * H * W + h * W + w];
}
return pixel;
Expand Down
85 changes: 60 additions & 25 deletions onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

template <typename T, bool Layout>
__device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
int64_t padding_mode, int64_t C, int64_t H, int64_t W, float border[4]) {
T pixel = 0.0f;

auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t {
Expand All @@ -69,8 +69,23 @@
y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
pixel = input_data[PixelOffset(x, y)];
} else { // Reflection
x = (int64_t)GsReflect<T>(x, border[0], border[2]);
y = (int64_t)GsReflect<T>(y, border[1], border[3]);
// Handle degenerate size-1 dimensions explicitly to avoid division by zero
// in GsReflect when x_min == x_max (range == 0), and clamp reflected
// coordinates into valid [0, H/W) ranges.
if (W == 1) {
x = 0;
} else {
x = (int64_t)GsReflect<T>(x, border[0], border[2]);
x = max((int64_t)0, min((int64_t)W - 1, x));
}

if (H == 1) {
y = 0;
} else {
y = (int64_t)GsReflect<T>(y, border[1], border[3]);
y = max((int64_t)0, min((int64_t)H - 1, y));
}

pixel = input_data[PixelOffset(x, y)];
}
return pixel;
Expand Down Expand Up @@ -100,8 +115,8 @@

template <typename T, bool Layout>
__global__ void _GridSampleKernel(
const T* input_data,
const T* grid_data,
const T* __restrict__ input_data,
const T* __restrict__ grid_data,
const int64_t mode,
const int64_t padding_mode,
const int64_t align_corners,
Expand All @@ -111,7 +126,7 @@
const int64_t W_in,
const int64_t H_out,
const int64_t W_out,
T* output_data) {
T* __restrict__ output_data) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
// extract batch index, channel index, y index, x index for current thread
int BIdx, yIdx, xIdx, cIdx;
Expand Down Expand Up @@ -200,10 +215,10 @@
w_lb = w_b * w_l;
w_rb = w_b * w_r;

T lt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
T rt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
T lb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
T rb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T lt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x1, padding_mode, C, H_in, W_in, border);
T rt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x2, padding_mode, C, H_in, W_in, border);
T lb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x1, padding_mode, C, H_in, W_in, border);
T rb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x2, padding_mode, C, H_in, W_in, border);
T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
output_data[outIdx] = interpoV;
return;
Expand All @@ -212,7 +227,7 @@
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
output_data[outIdx] =
PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, C, H_in, W_in, border);
return;
}
if (mode == 2) { // bicubic
Expand All @@ -222,7 +237,7 @@
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
p[h][w] =
PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, C, H_in, W_in, border);
}
}
T dx = grid_x_imgSpace - x0 - 1;
Expand Down Expand Up @@ -263,7 +278,7 @@

template <typename T, bool Layout>
__device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x,
int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) {
int64_t padding_mode, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) {
T pixel = 0.0f;

auto PixelOffset3D = [bIdx, cIdx, C, D, H, W](int64_t z, int64_t y, int64_t x) -> int64_t {
Expand All @@ -283,9 +298,29 @@

pixel = input_data[PixelOffset3D(z, y, x)];
} else { // Reflection
z = (int64_t)GsReflect<T>(z, border[0], border[3]);
y = (int64_t)GsReflect<T>(y, border[1], border[4]);
x = (int64_t)GsReflect<T>(x, border[2], border[5]);
// Handle degenerate size-1 dimensions explicitly to avoid division by zero
// in GsReflect when x_min == x_max (range == 0), and clamp reflected
// coordinates into valid [0, D/H/W) ranges.
if (D == 1) {
z = 0;
} else {
z = (int64_t)GsReflect<T>(z, border[0], border[3]);
z = max((int64_t)0, min((int64_t)D - 1, z));
}

if (H == 1) {
y = 0;
} else {
y = (int64_t)GsReflect<T>(y, border[1], border[4]);
y = max((int64_t)0, min((int64_t)H - 1, y));
}

if (W == 1) {
x = 0;
} else {
x = (int64_t)GsReflect<T>(x, border[2], border[5]);
x = max((int64_t)0, min((int64_t)W - 1, x));

Check warning on line 322 in onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu:322: Add #include <algorithm> for min [build/include_what_you_use] [4]
}

pixel = input_data[PixelOffset3D(z, y, x)];
}
Expand Down Expand Up @@ -433,15 +468,15 @@
w_lb_back = w_b * w_l * w_back;
w_rb_back = w_b * w_r * w_back;

T lt_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border);
T rt_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border);
T lb_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border);
T rb_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border);
T lt_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, C, D_in, H_in, W_in, border);
T rt_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, C, D_in, H_in, W_in, border);
T lb_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, C, D_in, H_in, W_in, border);
T rb_front_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, C, D_in, H_in, W_in, border);

T lt_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border);
T rt_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border);
T lb_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border);
T rb_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border);
T lt_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, C, D_in, H_in, W_in, border);
T rt_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, C, D_in, H_in, W_in, border);
T lb_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, C, D_in, H_in, W_in, border);
T rb_back_v = PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, C, D_in, H_in, W_in, border);

T interpoV = w_lt_front * lt_front_v + w_rt_front * rt_front_v + w_lb_front * lb_front_v + w_rb_front * rb_front_v +
w_lt_back * lt_back_v + w_rt_back * rt_back_v + w_lb_back * lb_back_v + w_rb_back * rb_back_v;
Expand All @@ -455,7 +490,7 @@
int z_n = grid_z_volSpace;

output_data[outIdx] =
PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z_n, y_n, x_n, padding_mode, N, C, D_in, H_in, W_in, border);
PixelAtGrid3D<T, Layout>(input_data, BIdx, cIdx, z_n, y_n, x_n, padding_mode, C, D_in, H_in, W_in, border);
return;
}
if (mode == 2) { // cubic
Expand Down
57 changes: 57 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,60 @@ TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bound
test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
RunTests(test, GetExecutionProviders());
}

TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_reflection_degenerate_spatial_dims_align_corners) {
  // Regression test for a 1x1 image sampled with align_corners=1.
  // Bilinear sampling on such an input always touches the neighbor indices {0, 1}.
  // With reflection padding the out-of-bounds index 1 has to fold back onto index 0
  // rather than running invalid reflection math on CUDA, so every output element is
  // expected to equal the single input pixel.
  OpTester test("GridSample", 20);

  // The one and only pixel of the input image; also the expected value everywhere.
  const TypeParam pixel(3.25f);

  // Input image: N=1, C=1, H=1, W=1.
  std::initializer_list<int64_t> input_dims{1, 1, 1, 1};
  std::initializer_list<TypeParam> input_values{pixel};
  test.AddInput<TypeParam>("X", input_dims, input_values);

  // Sampling grid: 2x2 output locations, each an (x, y) pair in normalized coordinates.
  std::initializer_list<int64_t> grid_dims{1, 2, 2, 2};
  std::initializer_list<TypeParam> grid_values{
      TypeParam(-1.0f), TypeParam(-1.0f),
      TypeParam(0.0f), TypeParam(0.0f),
      TypeParam(0.5f), TypeParam(-0.5f),
      TypeParam(1.0f), TypeParam(1.0f)};
  test.AddInput<TypeParam>("Grid", grid_dims, grid_values);

  // Operator attributes under test.
  std::string interpolation = "linear";
  test.AddAttribute("mode", interpolation);
  std::string padding = "reflection";
  test.AddAttribute("padding_mode", padding);
  int64_t corners_aligned = 1;
  test.AddAttribute("align_corners", corners_aligned);

  // Expected output: the lone pixel replicated at all four sample points.
  std::initializer_list<int64_t> output_dims{1, 1, 2, 2};
  std::initializer_list<TypeParam> output_values{pixel, pixel, pixel, pixel};
  test.AddOutput<TypeParam>("Y", output_dims, output_values);

  RunTests(test, GetExecutionProviders());
}

TYPED_TEST(GridSampleCustomTest, test_grid_sample_22_5D_linear_reflection_degenerate_spatial_dims_align_corners) {
  // 3D counterpart of the degenerate-dimension regression test: a 1x1x1 volume
  // forces trilinear interpolation to read reflected neighbors at index +1, which
  // used to be unsafe on CUDA. Every output element is expected to equal the single
  // input voxel.
  OpTester test("GridSample", 22);

  // The one and only voxel of the input volume; also the expected value everywhere.
  const TypeParam voxel(-2.5f);

  // Input volume: N=1, C=1, D=1, H=1, W=1.
  std::initializer_list<int64_t> input_dims{1, 1, 1, 1, 1};
  std::initializer_list<TypeParam> input_values{voxel};
  test.AddInput<TypeParam>("X", input_dims, input_values);

  // Sampling grid: 2x2x1 output locations, each a normalized 3-component coordinate.
  std::initializer_list<int64_t> grid_dims{1, 2, 2, 1, 3};
  std::initializer_list<TypeParam> grid_values{
      TypeParam(-1.0f), TypeParam(-1.0f), TypeParam(-1.0f),
      TypeParam(0.0f), TypeParam(0.0f), TypeParam(0.0f),
      TypeParam(0.5f), TypeParam(-0.5f), TypeParam(1.0f),
      TypeParam(1.0f), TypeParam(1.0f), TypeParam(1.0f)};
  test.AddInput<TypeParam>("Grid", grid_dims, grid_values);

  // Operator attributes under test.
  std::string interpolation = "linear";
  test.AddAttribute("mode", interpolation);
  std::string padding = "reflection";
  test.AddAttribute("padding_mode", padding);
  int64_t corners_aligned = 1;
  test.AddAttribute("align_corners", corners_aligned);

  // Expected output: the lone voxel replicated at all four sample points.
  std::initializer_list<int64_t> output_dims{1, 1, 2, 2, 1};
  std::initializer_list<TypeParam> output_values{voxel, voxel, voxel, voxel};
  test.AddOutput<TypeParam>("Y", output_dims, output_values);

  RunTests(test, GetExecutionProviders());
}
Loading