// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <algorithm>
#include <cmath>
#include <string>
#include "core/util/math_cpuonly.h"
#include "core/common/common.h"
#include "core/framework/tensor.h"
#include "core/platform/threadpool.h"

#include "grid_sample.h"
| 10 | + |
| 11 | +namespace onnxruntime { |
| 12 | +namespace contrib { |
| 13 | + |
| 14 | +template <typename T> |
| 15 | +GridSample<T>::GridSample(const OpKernelInfo& info) : OpKernel(info) { |
| 16 | + std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear") ; |
| 17 | + std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros"); |
| 18 | + align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0)); |
| 19 | + ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic", |
| 20 | + "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); |
| 21 | + ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection", |
| 22 | + "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); |
| 23 | + if (mode_str == "bicubic") { |
| 24 | + mode_ = Bicubic; |
| 25 | + } else if (mode_str == "nearest") { |
| 26 | + mode_ = Nearest; |
| 27 | + } else { |
| 28 | + mode_ = Bilinear; |
| 29 | + } |
| 30 | + if (padding_mode_str == "reflection") { |
| 31 | + padding_mode_ = Reflection; |
| 32 | + } else if (padding_mode_str == "border") { |
| 33 | + padding_mode_ = Border; |
| 34 | + } else { |
| 35 | + padding_mode_ = Zeros; |
| 36 | + } |
| 37 | +} |

// Restore a normalized location to an actual image location
// When align_corners is true:
// Normalized location (-1, -1) points to the top-left pixel.
// Normalized location (1, 1) points to the bottom-right pixel.
// When align_corners is false [default]:
// Normalized location (-1, -1) points to the top-left pixel minus half
// a pixel in both directions, i.e., (-0.5, -0.5) in actual image space.
// Normalized location (1, 1) points to the bottom-right pixel plus half
// a pixel in both directions, i.e., (H - 0.5, W - 0.5) in actual image space.
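// Worked example (illustrative, not from the spec): with length = 4,
// align_corners = false maps n = -1 -> -0.5, n = 0 -> 1.5, n = 1 -> 3.5,
// while align_corners = true maps n = -1 -> 0, n = 0 -> 1.5, n = 1 -> 3.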
template <typename T>
T GsDenormalize(T n, int64_t length, bool align_corners) {
  T x = {};
  if (align_corners) {  // align_corners: true => [-1, 1] to [0, length - 1]
    x = static_cast<T>((n + 1) / 2.f * (length - 1));
  } else {  // align_corners: false => [-1, 1] to [-0.5, length - 0.5]
    x = static_cast<T>(((n + 1) * length - 1) / 2.f);
  }
  return x;
}

// Reflect across the nearest border until the location falls within the borders
// Use float for borders to avoid potential issues with integer T
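// Worked example (illustrative): GsReflect(-1.3f, -0.5f, 4.5f) computes
// dx = 0.8 within the first period, so it returns x_min + 0.8 = 0.3, i.e. a
// point 0.8 past the left border is mirrored to 0.8 inside it.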
template <typename T>
T GsReflect(T x, float x_min, float x_max) {
  float dx = {};
  float fx = static_cast<float>(x);
  float range = x_max - x_min;
  if (fx < x_min) {
    dx = x_min - fx;
    int n = static_cast<int>(dx / range);
    float r = dx - n * range;
    if (n % 2 == 0) {
      fx = x_min + r;
    } else {
      fx = x_max - r;
    }
  } else if (fx > x_max) {
    dx = fx - x_max;
    int n = static_cast<int>(dx / range);
    float r = dx - n * range;
    if (n % 2 == 0) {
      fx = x_max - r;
    } else {
      fx = x_min + r;
    }
  }
  // else: already within the borders, fall through
  return static_cast<T>(fx);
}

// Calculate cubic convolution interpolation coefficients
// Robert G. Keys, https://ieeexplore.ieee.org/document/1163711
// Use float to avoid potential issues with integer T
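// Sanity check (illustrative, with cubic_alpha = -0.75): at x = 0 the
// coefficients are {0, 1, 0, 0}; at x = 0.5 they are
// {-0.09375, 0.59375, 0.59375, -0.09375}. In both cases they sum to 1.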
void GsGetCubicCoeffs(float x, float coeffs[4]) {
  constexpr float cubic_alpha = -0.75f;
  x = std::abs(x);
  coeffs[0] = ((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha;
  coeffs[1] = ((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1;
  coeffs[2] = ((cubic_alpha + 2) * (1 - x) - (cubic_alpha + 3)) * (1 - x) * (1 - x) + 1;
  coeffs[3] = ((cubic_alpha * (2 - x) - 5 * cubic_alpha) * (2 - x) + 8 * cubic_alpha) * (2 - x) - 4 * cubic_alpha;
}

template <typename T>
T GsBicubicInterpolate(T p[4][4], float x, float y) {
  float v[4] = {};
  float coeffs[4] = {};
  GsGetCubicCoeffs(x, coeffs);
  for (int64_t i = 0; i < 4; i++) {
    v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
  }
  GsGetCubicCoeffs(y, coeffs);
  return static_cast<T>(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
}

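// border is laid out as {x_min, y_min, x_max, y_max} (left, top, right,
// bottom) in actual image coordinates; it is built in Compute below.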
template <typename T>
T GridSample<T>::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const {
  T pixel = {};  // default 0
  if (padding_mode_ == Zeros) {
    if (c >= 0 && c < W && r >= 0 && r < H) {
      pixel = image[r * W + c];
    }
  } else if (padding_mode_ == Border) {
    c = std::clamp<int64_t>(c, 0, W - 1);
    r = std::clamp<int64_t>(r, 0, H - 1);
    pixel = image[r * W + c];
  } else {  // padding_mode_ == Reflection
    c = static_cast<int64_t>(GsReflect(static_cast<T>(c), border[0], border[2]));
    r = static_cast<int64_t>(GsReflect(static_cast<T>(r), border[1], border[3]));
    pixel = image[r * W + c];
  }
  return pixel;
}

// When grid sampling, padding is applied before interpolation.
// For instance, in bilinear mode with zeros padding, a pixel p at actual
// image location (-0.5, -0.5)
//
//   0   0   <-- zero padding
//        p
//   0   p00  p01 ...
//
//       p10  p11 ...
//       ...
//
// would be interpolated as p = p00 / 4
//
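// Check (illustrative): (-0.5, -0.5) gives x1 = y1 = -1 and x2 = y2 = 0, so
// all four bilinear weights are 0.5 * 0.5 = 0.25; with zeros padding only the
// (0, 0) neighbor p00 is non-zero, hence p = 0.25 * p00.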
template <typename T>
Status GridSample<T>::Compute(OpKernelContext* context) const {
  const auto* input = context->Input<Tensor>(0);
  const auto* grid = context->Input<Tensor>(1);
  const auto& input_dims = input->Shape();
  const auto& grid_dims = grid->Shape();

  if (input_dims.NumDimensions() != 4 || grid_dims.NumDimensions() != 4) {
    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensors are supported");
  }

  auto N = input_dims[0];
  auto C = input_dims[1];
  auto H_in = input_dims[2];
  auto W_in = input_dims[3];
  auto H_out = grid_dims[1];
  auto W_out = grid_dims[2];
  ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N);
  ORT_ENFORCE(grid_dims[3] == 2, "Last dimension of grid: ", grid_dims[3], ", expect 2");

  TensorShape Y_shape = {N, C, H_out, W_out};
  auto& Y = *context->Output(0, Y_shape);
  // Return early if the output tensor is going to be of size 0
  if (Y.Shape().Size() == 0) {
    return Status::OK();
  }

  // Force float here to avoid possible issues in the integer T case
  float x_min = -0.5f;
  float x_max = W_in - 0.5f;
  float y_min = -0.5f;
  float y_max = H_in - 0.5f;

  if (align_corners_) {
    x_min = 0.f;
    x_max = W_in - 1.f;
    y_min = 0.f;
    y_max = H_in - 1.f;
  }
  float border[] = {x_min, y_min, x_max, y_max};  // l-t-r-b

  concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
  for (int64_t n = 0; n < N; n++) {
    const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;
    concurrency::ThreadPool::TrySimpleParallelFor(
        tp, C,
        [&](std::ptrdiff_t c) {
          const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
          T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);

          for (int64_t oy = 0; oy < H_out; oy++) {
            for (int64_t ox = 0; ox < W_out; ox++) {
              const T* gridpoint = grid_data + (oy * W_out + ox) * 2;
              T* Y_gridpoint = Y_data + oy * W_out + ox;
              auto nx = gridpoint[0];  // normalized location
              auto ny = gridpoint[1];
              auto x = GsDenormalize<T>(nx, W_in, align_corners_);  // actual location
              auto y = GsDenormalize<T>(ny, H_in, align_corners_);

              if (mode_ == Nearest) {
                x = static_cast<T>(std::nearbyintf(static_cast<float>(x)));
                y = static_cast<T>(std::nearbyintf(static_cast<float>(y)));
              }

              if (x < x_min || x > x_max || y < y_min || y > y_max) {  // out of bound
                if (padding_mode_ == Border) {
                  // use original border in both align_corners cases
                  x = std::clamp(x, static_cast<T>(0), static_cast<T>(W_in - 1));
                  y = std::clamp(y, static_cast<T>(0), static_cast<T>(H_in - 1));
                } else if (padding_mode_ == Reflection) {
                  x = GsReflect(x, x_min, x_max);
                  y = GsReflect(y, y_min, y_max);
                }
              }  // out of bound

              if (mode_ == Nearest) {
                // x, y are integers in all padding modes
                *Y_gridpoint = PixelAtGrid(X_data, static_cast<int64_t>(y), static_cast<int64_t>(x), H_in, W_in, border);
                continue;
              }

              if (mode_ == Bilinear) {
                int64_t x1 = static_cast<int64_t>(std::floor(x));
                int64_t y1 = static_cast<int64_t>(std::floor(y));
                int64_t x2 = x1 + 1;
                int64_t y2 = y1 + 1;

                T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border);
                T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border);
                T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border);
                T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border);

                T dx2 = static_cast<T>(x2) - x;
                T dx1 = x - static_cast<T>(x1);
                T dy2 = static_cast<T>(y2) - y;
                T dy1 = y - static_cast<T>(y1);
                *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
              }
              if (mode_ == Bicubic) {
                int64_t x0 = static_cast<int64_t>(std::floor(x)) - 1;  // top-left corner of the bbox
                int64_t y0 = static_cast<int64_t>(std::floor(y)) - 1;
                T p[4][4] = {};  // [H][W]
                for (int64_t h = 0; h < 4; h++) {
                  for (int64_t w = 0; w < 4; w++) {
                    p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border);
                  }
                }
                T dx = static_cast<T>(x - x0 - 1);
                T dy = static_cast<T>(y - y0 - 1);
                *Y_gridpoint = GsBicubicInterpolate(p, static_cast<float>(dx), static_cast<float>(dy));
              }
            }
          }
        });
  }
  return Status::OK();
}

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      GridSample,                                                 \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCpuExecutionProvider,                                      \
      KernelDefBuilder()                                          \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      GridSample<T>);

REGISTER_KERNEL_TYPED(float)

}  // namespace contrib
}  // namespace onnxruntime