
Commit 19b82b4

GridSample OP implementation for CPU and CUDA (#8551)
* GridSample OP implementation for CPU and CUDA
  **Description**: This change contains an implementation of the torch grid_sample OP. The CUDA implementation contains a contribution from Muscle Wu.
* Use interpolation for out-of-bound points in zeros padding mode
  Out-of-bound points in zeros padding mode changed from constant 0 to interpolation of the surrounding pixels (a worked example of this case follows the file summary below). This aligns with the PyTorch implementation. A bug in the CUDA batch offset calculation is fixed. A custom op exporter type is added.
* Fix nearest bug in CPU
* Update per CI build findings and review comments
* Force float to avoid a potential integer T issue
* Style update
* PR update
* Remove a C++17 feature from the CUDA code
1 parent 6f2f472 commit 19b82b4

12 files changed: +735 −5 lines
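The key behavioral change described above is that, in zeros padding mode, an out-of-bound sample point is now bilinearly interpolated from its in-bound neighbours rather than forced to 0, matching PyTorch. The standalone sketch below is not part of this commit; the 2x2 image values are made up for illustration. It reproduces the corner case documented in the new kernel: a sample at actual image location (-0.5, -0.5) has three of its four bilinear neighbours in the zero padding, so the result is p00 / 4.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical 1-channel 2x2 image; p00 = 8 is the top-left pixel.
static const float kImage[2][2] = {{8.f, 3.f}, {5.f, 1.f}};

// Zeros padding: any read outside the image yields 0.
float PixelOrZero(int64_t r, int64_t c) {
  return (r >= 0 && r < 2 && c >= 0 && c < 2) ? kImage[r][c] : 0.f;
}

// Bilinear sample at an actual (possibly out-of-bound) image location (x, y),
// with padding applied before interpolation, as in the new kernel.
float BilinearZeros(float x, float y) {
  int64_t x1 = static_cast<int64_t>(std::floor(x)), x2 = x1 + 1;
  int64_t y1 = static_cast<int64_t>(std::floor(y)), y2 = y1 + 1;
  float p11 = PixelOrZero(y1, x1), p12 = PixelOrZero(y1, x2);
  float p21 = PixelOrZero(y2, x1), p22 = PixelOrZero(y2, x2);
  float dx2 = static_cast<float>(x2) - x, dx1 = x - static_cast<float>(x1);
  float dy2 = static_cast<float>(y2) - y, dy1 = y - static_cast<float>(y1);
  return dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
}

int main() {
  // (-0.5, -0.5) sits halfway into the zero padding on both axes: only p00
  // contributes, with weight 0.25, so the result is 8 / 4 = 2.
  std::printf("sample at (-0.5, -0.5) = %g\n", BilinearZeros(-0.5f, -0.5f));
  return 0;
}
```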

docs/OperatorKernels.md

+2 lines

@@ -383,6 +383,7 @@ Do not modify directly.*
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GatherND|*in* data:**T**<br> *in* indices:**Tind**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
+|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T** = tensor(float)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
@@ -720,6 +721,7 @@ Do not modify directly.*
 |FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T** = tensor(float)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|

onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

+2 lines

@@ -10,6 +10,7 @@ namespace contrib {
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
 
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
@@ -177,6 +178,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,
 
       // add more kernels here
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>,
+275 lines (new file)

@@ -0,0 +1,275 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cmath>
#include "core/util/math_cpuonly.h"
#include "core/common/common.h"
#include "core/framework/tensor.h"
#include "core/platform/threadpool.h"

#include "grid_sample.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
GridSample<T>::GridSample(const OpKernelInfo& info) : OpKernel(info) {
  std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");
  std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
  align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0));
  ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic",
              "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
  ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
              "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
  if (mode_str == "bicubic") {
    mode_ = Bicubic;
  } else if (mode_str == "nearest") {
    mode_ = Nearest;
  } else {
    mode_ = Bilinear;
  }
  if (padding_mode_str == "reflection") {
    padding_mode_ = Reflection;
  } else if (padding_mode_str == "border") {
    padding_mode_ = Border;
  } else {
    padding_mode_ = Zeros;
  }
}

// Restore a normalized location to the actual image location.
// When align_corners is true:
//   Normalized location (-1, -1) points to the top-left pixel.
//   Normalized location (1, 1) points to the bottom-right pixel.
// When align_corners is false [default]:
//   Normalized location (-1, -1) points to the top-left pixel minus half
//   a pixel in both directions, i.e. (-0.5, -0.5) in actual image space.
//   Normalized location (1, 1) points to the bottom-right pixel plus half
//   a pixel in both directions, i.e. (H - 0.5, W - 0.5) in actual image space.
template <typename T>
T GsDenormalize(T n, int64_t length, bool align_corners) {
  T x = {};
  if (align_corners) {  // align_corners: true => [-1, 1] to [0, length - 1]
    x = static_cast<T>((n + 1) / 2.f * (length - 1));
  } else {  // align_corners: false => [-1, 1] to [-0.5, length - 0.5]
    x = static_cast<T>(((n + 1) * length - 1) / 2.f);
  }
  return x;
}

// Reflect by the nearer border until the location is within the borders.
// Use float for borders to avoid potential issues with integer T.
template <typename T>
T GsReflect(T x, float x_min, float x_max) {
  float dx = {};
  float fx = static_cast<float>(x);
  float range = x_max - x_min;
  if (fx < x_min) {
    dx = x_min - fx;
    int n = static_cast<int>(dx / range);
    float r = dx - n * range;
    if (n % 2 == 0) {
      fx = x_min + r;
    } else {
      fx = x_max - r;
    }
  } else if (fx > x_max) {
    dx = fx - x_max;
    int n = static_cast<int>(dx / range);
    float r = dx - n * range;
    if (n % 2 == 0) {
      fx = x_max - r;
    } else {
      fx = x_min + r;
    }
  }
  // else fall through: x is already within [x_min, x_max]
  return static_cast<T>(fx);
}

// Calculate cubic convolution interpolation coefficients.
// ROBERT G. KEYS https://ieeexplore.ieee.org/document/1163711
// Use float to avoid potential issues with integer T.
void GsGetCubicCoeffs(float x, float coeffs[4]) {
  constexpr float cubic_alpha = -0.75f;
  x = std::abs(x);
  coeffs[0] = ((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha;
  coeffs[1] = ((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1;
  coeffs[2] = ((cubic_alpha + 2) * (1 - x) - (cubic_alpha + 3)) * (1 - x) * (1 - x) + 1;
  coeffs[3] = ((cubic_alpha * (2 - x) - 5 * cubic_alpha) * (2 - x) + 8 * cubic_alpha) * (2 - x) - 4 * cubic_alpha;
}

template <typename T>
T GsBicubicInterpolate(T p[4][4], float x, float y) {
  float v[4] = {};
  float coeffs[4] = {};
  GsGetCubicCoeffs(x, coeffs);
  for (int64_t i = 0; i < 4; i++) {
    v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
  }
  GsGetCubicCoeffs(y, coeffs);
  return static_cast<T>(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
}

template <typename T>
T GridSample<T>::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const {
  T pixel = {};  // default 0
  if (padding_mode_ == Zeros) {
    if (c >= 0 && c < W && r >= 0 && r < H) {
      pixel = image[r * W + c];
    }
  } else if (padding_mode_ == Border) {
    c = std::clamp<int64_t>(c, 0, W - 1);
    r = std::clamp<int64_t>(r, 0, H - 1);
    pixel = image[r * W + c];
  } else {  // (padding_mode_ == Reflection)
    c = static_cast<int64_t>(GsReflect(static_cast<T>(c), border[0], border[2]));
    r = static_cast<int64_t>(GsReflect(static_cast<T>(r), border[1], border[3]));
    pixel = image[r * W + c];
  }
  return pixel;
}

// When grid sampling, padding is applied before interpolation.
// For instance, in bilinear mode and zeros padding-mode, pixel p at actual
// image location (-0.5, -0.5):
//
//   0   0   <-- zero padding
//     p
//   0   p00 p01 ...
//
//       p10 p11 ...
//       ...
//
// would be interpolated as p = p00 / 4.
template <typename T>
Status GridSample<T>::Compute(OpKernelContext* context) const {
  const auto* input = context->Input<Tensor>(0);
  const auto* grid = context->Input<Tensor>(1);
  const auto& input_dims = input->Shape();
  const auto& grid_dims = grid->Shape();

  if (input_dims.NumDimensions() != 4 || grid_dims.NumDimensions() != 4) {
    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported");
  }

  auto N = input_dims[0];
  auto C = input_dims[1];
  auto H_in = input_dims[2];
  auto W_in = input_dims[3];
  auto H_out = grid_dims[1];
  auto W_out = grid_dims[2];
  ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N);
  ORT_ENFORCE(grid_dims[3] == 2, "Last dimension of grid: ", grid_dims[3], ", expect 2");

  TensorShape Y_shape = {N, C, H_out, W_out};
  auto& Y = *context->Output(0, Y_shape);
  // Return early if the output tensor is going to be of size 0
  if (Y.Shape().Size() == 0) {
    return Status::OK();
  }

  // Force float here to avoid a possible issue in the integer T case
  float x_min = -0.5f;
  float x_max = W_in - 0.5f;
  float y_min = -0.5f;
  float y_max = H_in - 0.5f;

  if (align_corners_) {
    x_min = 0.f;
    x_max = W_in - 1.f;
    y_min = 0.f;
    y_max = H_in - 1.f;
  }
  float border[] = {x_min, y_min, x_max, y_max};  // l-t-r-b

  concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
  for (int64_t n = 0; n < N; n++) {
    const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;
    concurrency::ThreadPool::TrySimpleParallelFor(
        tp, C,
        [&](std::ptrdiff_t c) {
          const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
          T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);

          for (int64_t oy = 0; oy < H_out; oy++) {
            for (int64_t ox = 0; ox < W_out; ox++) {
              const T* gridpoint = grid_data + (oy * W_out + ox) * 2;
              T* Y_gridpoint = Y_data + oy * W_out + ox;
              auto nx = gridpoint[0];  // normalized location
              auto ny = gridpoint[1];
              auto x = GsDenormalize<T>(nx, W_in, align_corners_);  // actual location
              auto y = GsDenormalize<T>(ny, H_in, align_corners_);

              if (mode_ == Nearest) {
                x = static_cast<T>(std::nearbyintf(static_cast<float>(x)));
                y = static_cast<T>(std::nearbyintf(static_cast<float>(y)));
              }

              if (x < x_min || x > x_max || y < y_min || y > y_max) {  // out of bound
                if (padding_mode_ == Border) {
                  // use the original border in both align_corners cases
                  x = std::clamp(x, static_cast<T>(0), static_cast<T>(W_in - 1));
                  y = std::clamp(y, static_cast<T>(0), static_cast<T>(H_in - 1));
                } else if (padding_mode_ == Reflection) {
                  x = GsReflect(x, x_min, x_max);
                  y = GsReflect(y, y_min, y_max);
                }
              }  // out of bound

              if (mode_ == Nearest) {
                // x, y are integers in all padding modes
                *Y_gridpoint = PixelAtGrid(X_data, static_cast<int64_t>(y), static_cast<int64_t>(x), H_in, W_in, border);
                continue;
              }

              if (mode_ == Bilinear) {
                int64_t x1 = static_cast<int64_t>(std::floor(x));
                int64_t y1 = static_cast<int64_t>(std::floor(y));
                int64_t x2 = x1 + 1;
                int64_t y2 = y1 + 1;

                T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border);
                T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border);
                T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border);
                T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border);

                T dx2 = static_cast<T>(x2) - x;
                T dx1 = x - static_cast<T>(x1);
                T dy2 = static_cast<T>(y2) - y;
                T dy1 = y - static_cast<T>(y1);
                *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
              }
              if (mode_ == Bicubic) {
                int64_t x0 = static_cast<int64_t>(std::floor(x)) - 1;  // top-left corner of the bbox
                int64_t y0 = static_cast<int64_t>(std::floor(y)) - 1;
                T p[4][4] = {};  // [H][W]
                for (int64_t h = 0; h < 4; h++) {
                  for (int64_t w = 0; w < 4; w++) {
                    p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border);
                  }
                }
                T dx = static_cast<T>(x - x0 - 1);
                T dy = static_cast<T>(y - y0 - 1);
                *Y_gridpoint = GsBicubicInterpolate(p, static_cast<float>(dx), static_cast<float>(dy));
              }
            }
          }
        });
  }
  return Status::OK();
}

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      GridSample,                                                 \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCpuExecutionProvider,                                      \
      KernelDefBuilder()                                          \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      GridSample<T>);

REGISTER_KERNEL_TYPED(float)

}  // namespace contrib
}  // namespace onnxruntime
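The comment blocks in the file above describe the normalized-to-image coordinate mapping and reflection padding only in prose. The standalone sketch below is not part of this commit: it re-implements `GsDenormalize` and `GsReflect` for float only, with bodies mirroring the ones above, and checks a few concrete values (the width-8 and width-4 axes are arbitrary examples).

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Mirrors GsDenormalize for float: map a normalized coordinate in [-1, 1]
// onto a length-`length` axis.
float Denormalize(float n, int64_t length, bool align_corners) {
  return align_corners ? (n + 1) / 2.f * (length - 1)
                       : ((n + 1) * length - 1) / 2.f;
}

// Mirrors GsReflect for float: reflect fx off the nearer border until it
// lands inside [x_min, x_max].
float Reflect(float fx, float x_min, float x_max) {
  float range = x_max - x_min;
  if (fx < x_min) {
    float dx = x_min - fx;
    int n = static_cast<int>(dx / range);
    float r = dx - n * range;
    fx = (n % 2 == 0) ? x_min + r : x_max - r;
  } else if (fx > x_max) {
    float dx = fx - x_max;
    int n = static_cast<int>(dx / range);
    float r = dx - n * range;
    fx = (n % 2 == 0) ? x_max - r : x_min + r;
  }
  return fx;
}

int main() {
  // align_corners = true: (-1, 1) maps exactly onto (0, length - 1).
  assert(Denormalize(-1.f, 8, true) == 0.f);
  assert(Denormalize(1.f, 8, true) == 7.f);
  // align_corners = false: (-1, 1) maps onto (-0.5, length - 0.5).
  assert(Denormalize(-1.f, 8, false) == -0.5f);
  assert(Denormalize(1.f, 8, false) == 7.5f);

  // Reflection over a width-4 axis with align_corners = true, i.e. borders
  // [0, 3]: 4.5 reflects back to 1.5, and -1.0 reflects back to 1.0.
  assert(Reflect(4.5f, 0.f, 3.f) == 1.5f);
  assert(Reflect(-1.f, 0.f, 3.f) == 1.f);

  std::printf("all checks passed\n");
  return 0;
}
```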
+40 lines (new file)

@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class GridSample final : public OpKernel {
 public:
  explicit GridSample(const OpKernelInfo& info);
  Status Compute(OpKernelContext* context) const override;

 private:
  enum GridSampleInterpolationMode {
    Bilinear,
    Nearest,
    Bicubic
  };

  enum GridSamplePaddingMode {
    Zeros,
    Border,
    Reflection
  };

  T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const;

  GridSampleInterpolationMode mode_{Bilinear};
  GridSamplePaddingMode padding_mode_{Zeros};
  bool align_corners_{false};
};

}  // namespace contrib
}  // namespace onnxruntime
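The header above completes the CPU kernel definition, but the diff does not show how the contrib op is consumed at runtime. Below is a minimal, hypothetical sketch (not part of this commit) of running a model that already contains a com.microsoft GridSample node through the ONNX Runtime C++ API. The model file name, the input names "X"/"Grid", the output name "Y", and the tensor shapes are all assumptions for illustration.

```cpp
#include <onnxruntime_cxx_api.h>
#include <array>
#include <cstdint>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "grid_sample_demo");
  Ort::SessionOptions options;
  // Hypothetical model containing a com.microsoft GridSample node.
  Ort::Session session(env, ORT_TSTR("grid_sample_model.onnx"), options);

  // X: [N, C, H_in, W_in], Grid: [N, H_out, W_out, 2] (assumed shapes).
  std::array<int64_t, 4> x_shape{1, 1, 4, 4};
  std::array<int64_t, 4> grid_shape{1, 2, 2, 2};
  std::vector<float> x_data(16, 1.0f);
  std::vector<float> grid_data(8, 0.0f);  // every point samples the image center

  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<Ort::Value> inputs;
  inputs.push_back(Ort::Value::CreateTensor<float>(
      memory_info, x_data.data(), x_data.size(), x_shape.data(), x_shape.size()));
  inputs.push_back(Ort::Value::CreateTensor<float>(
      memory_info, grid_data.data(), grid_data.size(), grid_shape.data(), grid_shape.size()));

  const char* input_names[] = {"X", "Grid"};
  const char* output_names[] = {"Y"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(),
                             inputs.size(), output_names, 1);
  // outputs[0] holds Y with shape [N, C, H_out, W_out].
  return 0;
}
```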

onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

+2 lines

@@ -9,6 +9,7 @@ using namespace onnxruntime::common;
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GridSample);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu);
@@ -100,6 +101,7 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
 Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
   static const BuildKernelCreateInfoFn function_table[] = {
       BuildKernelCreateInfo<void>,  // default entry to avoid the list becoming empty after ops-reducing
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GridSample)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FastGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu)>,
