Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for NHWC GridSample in the CUDA EP and enable grid_sample_test for all EPs #19562

Merged
merged 9 commits into from
Feb 23, 2024
21 changes: 14 additions & 7 deletions onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_
T pixel = 0.0f;

auto PixelOffset = [bIdx, cIdx, N, C, H, W](int64_t x, int64_t y) -> int64_t {
return Layout == LAYOUT_NCHW ? (bIdx * C * H * W + cIdx * H * W + y * W + x) : (bIdx * H * W * C + y * W * C + x * C + cIdx);
return Layout == LAYOUT_NCHW
? (bIdx * C * H * W + cIdx * H * W + y * W + x)
: (bIdx * H * W * C + y * W * C + x * C + cIdx);
};

if (padding_mode == 0) { // zeros
Expand Down Expand Up @@ -168,8 +170,8 @@ __global__ void _GridSampleKernel(
grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
if (padding_mode == 1) { // border
// Clamping must not be done here, see #10607
//grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
//grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
// grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
// grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
} else if (padding_mode == 2) { // reflection
grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
Expand Down Expand Up @@ -207,7 +209,8 @@ __global__ void _GridSampleKernel(
if (mode == 1) { // nearest
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
output_data[outIdx] = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
output_data[outIdx] =
PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
return;
}
if (mode == 2) { // bicubic
Expand All @@ -216,7 +219,8 @@ __global__ void _GridSampleKernel(
T p[4][4] = {}; // [H][W]
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
p[h][w] = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
p[h][w] =
PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
}
}
T dx = grid_x_imgSpace - x0 - 1;
Expand All @@ -239,9 +243,12 @@ void GridSampleImpl(
T* output_data) {
using Ch = Channels<IsNHWC>;

int blocksPerGrid = (int)(ceil(static_cast<T>(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));
int blocksPerGrid = static_cast<int>
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
ceil(static_cast<T>(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));
_GridSampleKernel<T, IsNHWC><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_data, grid_data, mode, padding_mode, align_corners, dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], H_out, W_out, output_data);
input_data, grid_data, mode, padding_mode, align_corners,
dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W],
H_out, W_out, output_data);
}

#define SPECIALIZED_IMPL(T, IsNHWC) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
}

#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
// TODO generate list from registered kernels using nhwc domain
// TODO(mtavenrath) generate list from registered kernels using nhwc domain
const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
return std::unordered_set<std::string_view>{
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace onnxruntime {
namespace test {

std::vector<std::unique_ptr<IExecutionProvider>> GetExecutionProviders(int opset_version) {
ORT_UNUSED_PARAMETER(opset_version);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;

execution_providers.emplace_back(DefaultCpuExecutionProvider());
Expand Down