From 28e9ea56a527139e734d94e9bbdd2ab38444b47c Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 28 Jan 2026 19:28:53 -0800 Subject: [PATCH 1/5] Initial commit --- .../core/providers/cuda/cuda_execution_provider.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) mode change 100644 => 100755 onnxruntime/core/providers/cuda/cuda_execution_provider.cc diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc old mode 100644 new mode 100755 index eb29e4edbf897..f2d3f28ba2bf6 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1421,7 +1421,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 19, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1510,6 +1510,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, 21, float, GridSample); // Opset 21. // TODO(fajin): support other quantized types @@ -1583,6 +1584,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GridSample); // Opset 23. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention); @@ -2485,7 +2487,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2582,6 +2584,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 21 // TODO(fajin): support other quantized types @@ -2654,6 +2657,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, From c117034394f2a7a9bc1e4027d6174e9003fd3395 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 28 Jan 2026 19:53:55 -0800 Subject: [PATCH 2/5] Commit --- .../core/providers/cuda/cuda_nhwc_kernels.cc | 9 +- .../core/providers/cuda/tensor/grid_sample.cc | 159 +++++++++--- .../core/providers/cuda/tensor/grid_sample.h | 1 + .../providers/cuda/tensor/grid_sample_impl.cu | 237 +++++++++++++++++- .../providers/cuda/tensor/grid_sample_impl.h | 14 ++ .../providers/cpu/tensor/grid_sample_test.cc | 114 +++++---- 6 files changed, 441 insertions(+), 93 deletions(-) mode change 100644 => 100755 onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc mode change 100644 => 100755 onnxruntime/core/providers/cuda/tensor/grid_sample.cc mode change 100644 => 100755 onnxruntime/core/providers/cuda/tensor/grid_sample.h mode change 100644 => 100755 onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu mode change 100644 => 100755 onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h mode change 100644 => 100755 onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc old mode 100644 new mode 100755 index e8995a0ec623a..8239a8ac252e6 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -164,12 +164,17 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { #ifndef DISABLE_CONTRIB_OPS namespace onnxruntime::contrib::cuda { -class CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample); +class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(16, 19, float, GridSample); +class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(20, 21, float, GridSample); +class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, GridSample); onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn nhwc_function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : nhwc_function_table) { diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc old mode 100644 new mode 100755 index 1884ae4689899..91736d7b48f7e --- a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc @@ -21,28 +21,66 @@ namespace cuda { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ onnxruntime::contrib::cuda::GridSample); +#define REGISTER_KERNEL_VERSIONED_TYPED(T, FROM_VERSION, TO_VERSION, LAYOUT, DOMAIN) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + GridSample, \ + DOMAIN, \ + FROM_VERSION, \ + TO_VERSION, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + onnxruntime::contrib::cuda::GridSample); + REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) #ifdef ENABLE_CUDA_NHWC_OPS -REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain) +// Op was introduced in opset 16 +REGISTER_KERNEL_VERSIONED_TYPED(float, 16, 19, LAYOUT_NHWC, kMSInternalNHWCDomain) + +// Op was modified to support multiple spatial dimensions in opset 20 +REGISTER_KERNEL_VERSIONED_TYPED(float, 20, 21, LAYOUT_NHWC, kMSInternalNHWCDomain) + +// Op spec introduced BFloat16 support in opset 22 +REGISTER_KERNEL_TYPED(float, 22, LAYOUT_NHWC, kMSInternalNHWCDomain) #endif template GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { + opset_start_version_ = info.node().SinceVersion(); + std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); - ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic", - "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); - ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection", - "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); - if (mode_str == "bicubic") { - mode_i_ = 2; - } else if (mode_str == "nearest") { - mode_i_ = 1; + + if (opset_start_version_ >= 20) { + std::string mode_str = info.GetAttrOrDefault("mode", "linear"); + if (mode_str == "cubic") { + mode_i_ = 2; + } else if (mode_str == "nearest") { + mode_i_ = 1; + } else if (mode_str == "linear") { + mode_i_ = 0; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic"); + } } else { - mode_i_ = 0; + std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); + if (mode_str == "bicubic") { + mode_i_ = 2; + } else if (mode_str == "nearest") { + mode_i_ = 1; + } else if (mode_str == "bilinear") { + mode_i_ = 0; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); + } } + + ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection", + "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); if (padding_mode_str == "reflection") { padding_mode_i_ = 2; } else if (padding_mode_str == "border") { @@ -59,20 +97,52 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { const Tensor* Grid = context->Input(1); const auto& dims_grid = Grid->Shape().GetDims(); - if (dims_input.size() != 4 || dims_grid.size() != 4) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported"); + if (dims_input.size() != dims_grid.size()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input and grid must have the same number of dimensions"); + } + + if (opset_start_version_ < 20 && dims_input.size() != 4) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Opset 16-19 versions of this op only supports 4-D input tensors"); + } + + if (dims_input[0] != dims_grid[0]) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Grid batch size does not match input batch size "); + } + + if (dims_input.size() != 4 && dims_input.size() != 5) { + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Only 4-D and 5-D input tensors are supported"); + } + + + if (dims_input.size() == 5 && mode_i_ == 2) { + // This is common for CPU and CUDA to not support Cubic mode for 5D input + // So it won't break CUDA users who were previously dropping down to CPU version of the op. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Cubic mode is only supported in 4-D cases."); + } + + if ((dims_input.size() == 4 && dims_grid[3] != 2) || (dims_input.size() == 5 && dims_grid[4] != 3)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Last dimension of grid input must match the number of " + "spatial dimensions in the input (2 for 2D, 3 for 3D)."); } - ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]); - ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2"); using Ch = Channels; - TensorShapeVector dims_output(4); - dims_output[Ch::N] = dims_input[Ch::N]; - dims_output[Ch::C] = dims_input[Ch::C]; - dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; - dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; + TensorShapeVector dims_output(dims_input.size()); + if (dims_input.size() == 4) { + dims_output[Ch::N] = dims_input[Ch::N]; + dims_output[Ch::C] = dims_input[Ch::C]; + dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; + } else { + // 5D NCHW layout: N, C, D, H, W + dims_output[0] = dims_input[0]; + dims_output[1] = dims_input[1]; + dims_output[2] = dims_grid[1 /* Grid::D */]; + dims_output[3] = dims_grid[2 /* Grid::H */]; + dims_output[4] = dims_grid[3 /* Grid::W */]; + } Tensor* Y = context->Output(0, dims_output); + // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { return Status::OK(); @@ -80,23 +150,50 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; CudaT* Y_data = reinterpret_cast(Y->MutableData()); - GridSampleImpl( - Stream(context), - reinterpret_cast(X->Data()), - reinterpret_cast(Grid->Data()), - mode_i_, - padding_mode_i_, - align_corners_, - dims_input.data(), - dims_grid[1], - dims_grid[2], - Y_data); + + if (dims_input.size() == 4) { + // sample 2d + GridSampleImpl( + Stream(context), + reinterpret_cast(X->Data()), + reinterpret_cast(Grid->Data()), + mode_i_, + padding_mode_i_, + align_corners_, + dims_input.data(), + dims_grid[1], + dims_grid[2], + Y_data); + } else { + // sample 3d + GridSampleImpl3D( + Stream(context), + reinterpret_cast(X->Data()), + reinterpret_cast(Grid->Data()), + mode_i_, + padding_mode_i_, + align_corners_, + dims_input.data(), + dims_grid[1], + dims_grid[2], + dims_grid[3], + Y_data); + } + + return Status::OK(); } } // namespace cuda } // namespace contrib namespace cuda { -REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +// Op was introduced in opset 16 +REGISTER_KERNEL_VERSIONED_TYPED(float, 16, 19, LAYOUT_NCHW, kOnnxDomain) + +// Op was modified to support multiple spatial dimensions in opset 20 +REGISTER_KERNEL_VERSIONED_TYPED(float, 20, 21, LAYOUT_NCHW, kOnnxDomain) + +// Op spec introduced BFloat16 support in opset 22 +REGISTER_KERNEL_TYPED(float, 22, LAYOUT_NCHW, kOnnxDomain) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.h b/onnxruntime/core/providers/cuda/tensor/grid_sample.h old mode 100644 new mode 100755 index 16581bfe77482..74be67c921aae --- a/onnxruntime/core/providers/cuda/tensor/grid_sample.h +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.h @@ -22,6 +22,7 @@ class GridSample final : public CudaKernel { int64_t mode_i_; // 0: bilinear (default), 1: nearest 2: bicubic int64_t padding_mode_i_; // 0:'zeros', 1: 'border', 2:'reflection' int64_t align_corners_; + int opset_start_version_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu old mode 100644 new mode 100755 index b5b4a84576bbe..9bda988e93552 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu @@ -142,6 +142,10 @@ __global__ void _GridSampleKernel( } int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; + // ONNX spec guideline about ordering of grid indices: + // Following computer vision convention, the coordinates in the length-r location vector + // are listed from the innermost tensor dimension to the outermost, the opposite of regular + // tensor indexing. T grid_X = grid_data[grid_idx * 2 + 0]; T grid_Y = grid_data[grid_idx * 2 + 1]; int outIdx = idx; @@ -159,9 +163,9 @@ __global__ void _GridSampleKernel( if (align_corners) { x_min = 0.0f; - x_max = W_in - 1.0; + x_max = float(W_in) - 1.0f; y_min = 0.0f; - y_max = H_in - 1.0f; + y_max = float(H_in) - 1.0f; } float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || @@ -257,6 +261,235 @@ void GridSampleImpl( SPECIALIZED_IMPL(float, false) // NCHW SPECIALIZED_IMPL(float, true) // NHWC +template +__device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x, + int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) { + T pixel = 0.0f; + + auto PixelOffset3D = [bIdx, cIdx, C, D, H, W](int64_t z, int64_t y, int64_t x) -> int64_t { + return Layout == LAYOUT_NCHW + ? (bIdx * C * D * H * W + cIdx * D * H * W + z * H * W + y * W + x) + : 0; // Placeholder for NHWC layout in 3D, to be implemented as needed + }; + + if (padding_mode == 0) { // zeros + if (z >= 0 && z < D && y >= 0 && y < H && x >= 0 && x < W) { + pixel = input_data[PixelOffset3D(z, y, x)]; + } + } else if (padding_mode == 1) { // border + z = max((int64_t)0, min((int64_t)D - 1, (int64_t)z)); + y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); + x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x)); + + pixel = input_data[PixelOffset3D(z, y, x)]; + } else { // Reflection + z = (int64_t)GsReflect(z, border[0], border[3]); + y = (int64_t)GsReflect(y, border[1], border[4]); + x = (int64_t)GsReflect(x, border[2], border[5]); + + pixel = input_data[PixelOffset3D(z, y, x)]; + } + return pixel; +} + +// Currently only supports NCHW layout for 3D grid sampling +// TODO(hasesh): Implement NHWC layout support if needed +template +__global__ void _GridSampleKernel3D( + const T* __restrict__ input_data, + const T* __restrict__ grid_data, + const int64_t mode, + const int64_t padding_mode, + const int64_t align_corners, + const int64_t N, + const int64_t C, + const int64_t D_in, + const int64_t H_in, + const int64_t W_in, + const int64_t D_out, + const int64_t H_out, + const int64_t W_out, + T* __restrict__ output_data) { + + if constexpr (Layout == LAYOUT_NHWC) { + // NHWC layout for 3D is not implemented + return; + } + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * D_out * H_out * W_out); + + // extract batch index, channel index, y index, x index, z index for current thread + int BIdx, cIdx, zIdx, yIdx, xIdx; + + BIdx = idx / (C * D_out * H_out * W_out); + int tmpBCnt = BIdx * (C * D_out * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (D_out * H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (D_out * H_out * W_out); + + zIdx = (idx - tmpCCnt) / (H_out * W_out); + int tmpDCnt = tmpCCnt + zIdx * (H_out * W_out); + + yIdx = (idx - tmpDCnt) / (W_out); + int tmpHCnt = tmpDCnt + yIdx * W_out; + + xIdx = (idx - tmpHCnt); + + int grid_idx = BIdx * D_out * H_out * W_out + zIdx * H_out * W_out + yIdx * W_out + xIdx; + + // ONNX spec guideline about ordering of grid indices: + // Following computer vision convention, the coordinates in the length-r location vector + // are listed from the innermost tensor dimension to the outermost, the opposite of regular + // tensor indexing. + T grid_X = grid_data[grid_idx * 3 + 0]; + T grid_Y = grid_data[grid_idx * 3 + 1]; + T grid_Z = grid_data[grid_idx * 3 + 2]; + + int outIdx = idx; + + T grid_x_volSpace = GsDenormalize(grid_X, W_in, align_corners == 1); + T grid_y_volSpace = GsDenormalize(grid_Y, H_in, align_corners == 1); + T grid_z_volSpace = GsDenormalize(grid_Z, D_in, align_corners == 1); + + if (mode == 1) { // nearest + grid_x_volSpace = nearbyint(grid_x_volSpace); + grid_y_volSpace = nearbyint(grid_y_volSpace); + grid_z_volSpace = nearbyint(grid_z_volSpace); + } + + float z_min = -0.5f; + float z_max = D_in - 0.5f; + float y_min = -0.5f; + float y_max = H_in - 0.5f; + float x_min = -0.5f; + float x_max = W_in - 0.5f; + + if (align_corners) { + z_min = 0.0f; + z_max = float(D_in) - 1.0f; + y_min = 0.0f; + y_max = float(H_in) - 1.0f; + x_min = 0.0f; + x_max = float(W_in) - 1.0f; + } + + float border[] = {z_min, y_min, x_min, z_max, y_max, x_max}; // zmin,ymin,xmin,zmax,ymax,xmax + if (grid_z_volSpace < z_min || grid_z_volSpace > z_max || + grid_y_volSpace < y_min || grid_y_volSpace > y_max || + grid_x_volSpace < x_min || grid_x_volSpace > x_max) { // out of bound + if (padding_mode == 1) { // border + // Clamping must not be done here, see #10607 + // grid_z_volSpace = max(0.0f, min(grid_z_volSpace, D_in - 1.0f)); + // grid_y_volSpace = max(0.0f, min(grid_y_volSpace, H_in - 1.0f)); + // grid_x_volSpace = max(0.0f, min(grid_x_volSpace, W_in - 1.0f)); + } else if (padding_mode == 2) { // reflection + grid_z_volSpace = GsReflect(grid_z_volSpace, z_min, z_max); + grid_y_volSpace = GsReflect(grid_y_volSpace, y_min, y_max); + grid_x_volSpace = GsReflect(grid_x_volSpace, x_min, x_max); + } + } + + if (mode == 0) { // bilinear + int z1 = floor(grid_z_volSpace); + int y1 = floor(grid_y_volSpace); + int x1 = floor(grid_x_volSpace); + int z2 = z1 + 1; + int y2 = y1 + 1; + int x2 = x1 + 1; + + // Weights + T w_lt_front = 0.0f; + T w_rt_front = 0.0f; + T w_lb_front = 0.0f; + T w_rb_front = 0.0f; + + T w_lt_back = 0.0f; + T w_rt_back = 0.0f; + T w_lb_back = 0.0f; + T w_rb_back = 0.0f; + + // Assign weight values + T w_back = grid_z_volSpace - z1; + T w_front = 1.0f - w_back; + T w_b = grid_y_volSpace - y1; + T w_t = 1.0f - w_b; + T w_r = grid_x_volSpace - x1; + T w_l = 1.0f - w_r; + + w_lt_front = w_t * w_l * w_front; + w_rt_front = w_t * w_r * w_front; + w_lb_front = w_b * w_l * w_front; + w_rb_front = w_b * w_r * w_front; + + w_lt_back = w_t * w_l * w_back; + w_rt_back = w_t * w_r * w_back; + w_lb_back = w_b * w_l * w_back; + w_rb_back = w_b * w_r * w_back; + + T lt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C,D_in, H_in, W_in, border); + T rt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border); + T lb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border); + T rb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border); + + T lt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border); + T rt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border); + T lb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border); + T rb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border); + + T interpoV = w_lt_front * lt_front_v + w_rt_front * rt_front_v + w_lb_front * lb_front_v + w_rb_front * rb_front_v + + w_lt_back * lt_back_v + w_rt_back * rt_back_v + w_lb_back * lb_back_v + w_rb_back * rb_back_v; + + output_data[outIdx] = interpoV; + return; + } + if (mode == 1) { // nearest + int x_n = grid_x_volSpace; + int y_n = grid_y_volSpace; + int z_n = grid_z_volSpace; + + output_data[outIdx] = + PixelAtGrid3D(input_data, BIdx, cIdx, z_n, y_n, x_n, padding_mode, N, C, D_in, H_in, W_in, border); + return; + } + if (mode == 2) { // cubic + // Not implemented for 3D input. But will not reach here per input validation + } +} + +template +void GridSampleImpl3D( + cudaStream_t stream, + const T* input_data, + const T* grid_data, + const int64_t mode, + const int64_t padding_mode, + const int64_t align_corners, + const int64_t dims[5], + const int64_t D_out, + const int64_t H_out, + const int64_t W_out, + T* output_data) { + + // Currently only NCHW layout is supported for 3D grid sampling + assert(IsNHWC == false); + + int blocksPerGrid = static_cast( + ceil(static_cast(dims[0] * dims[1] * D_out * H_out * W_out) / GridDim::maxThreadsPerBlock)); + _GridSampleKernel3D<<>>( + input_data, grid_data, mode, padding_mode, align_corners, + dims[0], dims[1], dims[2], dims[3], dims[4], + D_out, H_out, W_out, output_data); +} + +template void GridSampleImpl3D(cudaStream_t stream, const float* input_data, const float* grid_data, + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, + const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out, + float* output_data); + +template void GridSampleImpl3D(cudaStream_t stream, const float* input_data, const float* grid_data, + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, + const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out, + float* output_data); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h old mode 100644 new mode 100755 index 62cd66a48fa84..156de20ed11e3 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h @@ -21,6 +21,20 @@ void GridSampleImpl( const int64_t W_out, T* output_data); +template +void GridSampleImpl3D( + cudaStream_t stream, + const T* input_data, + const T* grid_data, + const int64_t mode, + const int64_t padding_mode, + const int64_t align_corners, + const int64_t dims_input[5], + const int64_t D_out, + const int64_t H_out, + const int64_t W_out, + T* output_data); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc old mode 100644 new mode 100755 index 05cfb5c13d689..f0114b442ae12 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -15,12 +15,10 @@ std::vector> GetExecutionProviders(int opset execution_providers.emplace_back(DefaultCpuExecutionProvider()); #ifdef USE_CUDA - if (opset_version < 20) { execution_providers.emplace_back(DefaultCudaExecutionProvider()); #ifdef ENABLE_CUDA_NHWC_OPS execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); #endif - } #endif #if defined(USE_COREML) @@ -64,7 +62,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { @@ -84,7 +82,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { @@ -104,7 +102,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { @@ -124,7 +122,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { @@ -144,7 +142,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { @@ -164,11 +162,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_co test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { + OpTester test("GridSample", 20); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; @@ -184,11 +182,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { + OpTester test("GridSample", 20); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; @@ -204,11 +202,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { + OpTester test("GridSample", 20); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -224,11 +222,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 20); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 0; @@ -244,11 +242,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_reflection_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; @@ -264,11 +262,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corn test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_reflection_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; @@ -284,11 +282,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_c test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_zeros_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; @@ -304,11 +302,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_zeros_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 0; @@ -324,11 +322,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_border_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -344,11 +342,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_border_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 0; @@ -364,11 +362,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_reflection_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; @@ -384,11 +382,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { - OpTester test("GridSample", 16); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_reflection_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; @@ -404,7 +402,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_co test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(16)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { @@ -707,8 +705,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corner RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_zeros_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; @@ -727,8 +725,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corner RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -747,8 +745,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -767,8 +765,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; @@ -787,8 +785,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corne RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; @@ -827,8 +825,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corn RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; @@ -867,8 +865,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_c RunTests(test, GetExecutionProviders(20)); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_no_align_corners) { + OpTester test("GridSample", 22); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; From b4f76de3c4b3dd2c1cca44f7bdfe295fa2726e94 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 28 Jan 2026 21:22:55 -0800 Subject: [PATCH 3/5] More changes --- .../onnx_transpose_optimization.cc | 2 +- .../core/providers/cuda/tensor/grid_sample.cc | 20 +-- .../providers/cuda/tensor/grid_sample_impl.cu | 67 ++++++---- .../providers/cpu/tensor/grid_sample_test.cc | 116 +++++++++--------- 4 files changed, 114 insertions(+), 91 deletions(-) mode change 100644 => 100755 onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc old mode 100644 new mode 100755 index 6f2538bcde3b1..8fff4f03010da --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -748,7 +748,7 @@ std::vector ChannelLastToFirstPerm(size_t rank) { } std::vector p(rank); - p[0] = 0; + p[0] = 0; // This is usually the batch dimension (hence preserve this position) p[1] = rank - 1; for (size_t i = 2; i < rank; ++i) { p[i] = i - 1; diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc index 91736d7b48f7e..75fa711859193 100755 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc @@ -109,6 +109,11 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Grid batch size does not match input batch size "); } + if ((dims_input.size() == 4 && dims_grid[3] != 2) || (dims_input.size() == 5 && dims_grid[4] != 3)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Last dimension of grid input must match the number of " + "spatial dimensions in the input (2 for 2D, 3 for 3D)."); + } + if (dims_input.size() != 4 && dims_input.size() != 5) { return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Only 4-D and 5-D input tensors are supported"); } @@ -120,11 +125,6 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Cubic mode is only supported in 4-D cases."); } - if ((dims_input.size() == 4 && dims_grid[3] != 2) || (dims_input.size() == 5 && dims_grid[4] != 3)) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Last dimension of grid input must match the number of " - "spatial dimensions in the input (2 for 2D, 3 for 3D)."); - } - using Ch = Channels; TensorShapeVector dims_output(dims_input.size()); @@ -134,12 +134,12 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; } else { - // 5D NCHW layout: N, C, D, H, W + // 5D input - deal with both NCHW and NHWC layouts dims_output[0] = dims_input[0]; - dims_output[1] = dims_input[1]; - dims_output[2] = dims_grid[1 /* Grid::D */]; - dims_output[3] = dims_grid[2 /* Grid::H */]; - dims_output[4] = dims_grid[3 /* Grid::W */]; + dims_output[1] = !IsNHWC ? dims_input[1] : dims_grid[1]; + dims_output[2] = !IsNHWC ? dims_grid[1] : dims_grid[2]; + dims_output[3] = !IsNHWC ? dims_grid[2] : dims_grid[3]; + dims_output[4] = !IsNHWC ? dims_grid[3] : dims_input[4]; } Tensor* Y = context->Output(0, dims_output); diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu index 9bda988e93552..dd29f62e83c08 100755 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu @@ -269,7 +269,7 @@ __device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int auto PixelOffset3D = [bIdx, cIdx, C, D, H, W](int64_t z, int64_t y, int64_t x) -> int64_t { return Layout == LAYOUT_NCHW ? (bIdx * C * D * H * W + cIdx * D * H * W + z * H * W + y * W + x) - : 0; // Placeholder for NHWC layout in 3D, to be implemented as needed + : (bIdx * D * H * W * C + z * H * W * C + y * W * C + x * C + cIdx); }; if (padding_mode == 0) { // zeros @@ -292,8 +292,6 @@ __device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int return pixel; } -// Currently only supports NCHW layout for 3D grid sampling -// TODO(hasesh): Implement NHWC layout support if needed template __global__ void _GridSampleKernel3D( const T* __restrict__ input_data, @@ -311,29 +309,39 @@ __global__ void _GridSampleKernel3D( const int64_t W_out, T* __restrict__ output_data) { - if constexpr (Layout == LAYOUT_NHWC) { - // NHWC layout for 3D is not implemented - return; - } - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * D_out * H_out * W_out); // extract batch index, channel index, y index, x index, z index for current thread int BIdx, cIdx, zIdx, yIdx, xIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * D_out * H_out * W_out); + int tmpBCnt = BIdx * (C * D_out * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (D_out * H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (D_out * H_out * W_out); + + zIdx = (idx - tmpCCnt) / (H_out * W_out); + int tmpDCnt = tmpCCnt + zIdx * (H_out * W_out); + + yIdx = (idx - tmpDCnt) / (W_out); + int tmpHCnt = tmpDCnt + yIdx * W_out; - BIdx = idx / (C * D_out * H_out * W_out); - int tmpBCnt = BIdx * (C * D_out * H_out * W_out); + xIdx = (idx - tmpHCnt); + } else { // Layout == LAYOUT_NHWC + BIdx = idx / (D_out * H_out * W_out * C); + int tmpBCnt = BIdx * (D_out * H_out * W_out * C); - cIdx = (idx - tmpBCnt) / (D_out * H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (D_out * H_out * W_out); + zIdx = (idx - tmpBCnt) / (H_out * W_out * C); + int tmpDCnt = tmpBCnt + zIdx * (H_out * W_out * C); - zIdx = (idx - tmpCCnt) / (H_out * W_out); - int tmpDCnt = tmpCCnt + zIdx * (H_out * W_out); + yIdx = (idx - tmpDCnt) / (W_out * C); + int tmpHCnt = tmpDCnt + yIdx * (W_out * C); - yIdx = (idx - tmpDCnt) / (W_out); - int tmpHCnt = tmpDCnt + yIdx * W_out; + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; - xIdx = (idx - tmpHCnt); + cIdx = (idx - tmpWCnt); + } int grid_idx = BIdx * D_out * H_out * W_out + zIdx * H_out * W_out + yIdx * W_out + xIdx; @@ -470,14 +478,31 @@ void GridSampleImpl3D( const int64_t W_out, T* output_data) { - // Currently only NCHW layout is supported for 3D grid sampling - assert(IsNHWC == false); + int64_t N = 0; + int64_t C = 0; + int64_t D_in = 0; + int64_t H_in = 0; + int64_t W_in = 0; + + if constexpr (IsNHWC) { + N = dims[0]; + D_in = dims[1]; + H_in = dims[2]; + W_in = dims[3]; + C = dims[4]; + } else { + N = dims[0]; + C = dims[1]; + D_in = dims[2]; + H_in = dims[3]; + W_in = dims[4]; + } int blocksPerGrid = static_cast( - ceil(static_cast(dims[0] * dims[1] * D_out * H_out * W_out) / GridDim::maxThreadsPerBlock)); + ceil(static_cast(N * C * D_out * H_out * W_out) / GridDim::maxThreadsPerBlock)); _GridSampleKernel3D<<>>( input_data, grid_data, mode, padding_mode, align_corners, - dims[0], dims[1], dims[2], dims[3], dims[4], + N, C, D_in, H_in, W_in, D_out, H_out, W_out, output_data); } diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index f0114b442ae12..18dc3a6d07446 100755 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -7,9 +7,7 @@ namespace onnxruntime { namespace test { -std::vector> GetExecutionProviders(int opset_version) { - ORT_UNUSED_PARAMETER(opset_version); - +std::vector> GetExecutionProviders() { std::vector> execution_providers; execution_providers.emplace_back(DefaultCpuExecutionProvider()); @@ -165,8 +163,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_co RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; @@ -185,8 +183,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; @@ -205,8 +203,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corner RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -225,8 +223,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { - OpTester test("GridSample", 20); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 0; @@ -245,8 +243,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corne RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_reflection_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; @@ -265,8 +263,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_reflection_align_corn RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_reflection_no_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; @@ -285,8 +283,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_reflection_no_align_c RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_zeros_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; @@ -305,8 +303,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_zeros_align_corners) { RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_zeros_no_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 0; @@ -325,8 +323,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_zeros_no_align_corners RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_border_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -345,8 +343,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_border_align_corners) RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_border_no_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 0; @@ -365,8 +363,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_border_no_align_corner RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_reflection_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; @@ -385,8 +383,8 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_reflection_align_corne RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bicubic_reflection_no_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { + OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; @@ -422,7 +420,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { @@ -442,7 +440,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { @@ -462,7 +460,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { @@ -482,7 +480,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { @@ -502,7 +500,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { @@ -522,7 +520,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { @@ -542,7 +540,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { @@ -562,7 +560,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { @@ -582,7 +580,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { @@ -602,7 +600,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { @@ -622,7 +620,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_co test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { @@ -642,7 +640,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_co test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { @@ -662,7 +660,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { @@ -682,7 +680,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { @@ -702,7 +700,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_zeros_no_align_corners) { @@ -722,11 +720,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_zeros_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } -TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) { - OpTester test("GridSample", 22); +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { + OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; @@ -742,7 +740,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) { @@ -762,7 +760,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_border_no_align_corners) { @@ -782,7 +780,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_border_no_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_no_align_corners) { @@ -802,7 +800,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_no_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { @@ -822,7 +820,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corn test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_align_corners) { @@ -842,7 +840,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_align_corn test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { @@ -862,7 +860,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_c test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_no_align_corners) { @@ -882,7 +880,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_no_align_c test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { @@ -902,7 +900,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { @@ -922,7 +920,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { @@ -942,7 +940,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { @@ -962,7 +960,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corner test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { @@ -982,7 +980,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corne test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { @@ -1002,7 +1000,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_co test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - RunTests(test, GetExecutionProviders(20)); + RunTests(test, GetExecutionProviders()); } } // namespace test From aaf3538fbcd7fec066492069247286386c8f75fb Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 30 Jan 2026 18:28:32 -0800 Subject: [PATCH 4/5] Doc update --- docs/OperatorKernels.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 7cc57a636362f..9a13cd76e477a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -240,7 +240,8 @@ Do not modify directly.* |||[13, 15]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|LpNormalization|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float)| +|LpNormalization|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(double), tensor(float)| +|||[1, 21]|**T** = tensor(double), tensor(float)| |LpPool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[18, 21]|**T** = tensor(float)| |||[11, 17]|**T** = tensor(float)| @@ -746,7 +747,9 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(float)
**T2** = tensor(float)| +|||[20, 21]|**T1** = tensor(float)
**T2** = tensor(float)| +|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |HardSwish|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -1061,7 +1064,9 @@ Do not modify directly.* |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| |GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(float)
**T2** = tensor(float)| +|||[20, 21]|**T1** = tensor(float)
**T2** = tensor(float)| +|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |MaxPool|*in* X:**T**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|12+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)| From 422d2b0d3d1dc74f38c6879cc76186b1f1864bc6 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 30 Jan 2026 18:32:22 -0800 Subject: [PATCH 5/5] Formatting --- .../onnx_transpose_optimization.cc | 2 +- .../core/providers/cuda/tensor/grid_sample.cc | 29 +++++++++---------- .../providers/cuda/tensor/grid_sample_impl.cu | 28 +++++++++--------- .../providers/cpu/tensor/grid_sample_test.cc | 4 +-- 4 files changed, 30 insertions(+), 33 deletions(-) diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 8fff4f03010da..f0b1a30b293ff 100755 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -748,7 +748,7 @@ std::vector ChannelLastToFirstPerm(size_t rank) { } std::vector p(rank); - p[0] = 0; // This is usually the batch dimension (hence preserve this position) + p[0] = 0; // This is usually the batch dimension (hence preserve this position) p[1] = rank - 1; for (size_t i = 2; i < rank; ++i) { p[i] = i - 1; diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc index 75fa711859193..b9d47a27e8e83 100755 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc @@ -23,16 +23,16 @@ namespace cuda { #define REGISTER_KERNEL_VERSIONED_TYPED(T, FROM_VERSION, TO_VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - GridSample, \ - DOMAIN, \ - FROM_VERSION, \ - TO_VERSION, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - onnxruntime::contrib::cuda::GridSample); + GridSample, \ + DOMAIN, \ + FROM_VERSION, \ + TO_VERSION, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + onnxruntime::contrib::cuda::GridSample); REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) @@ -110,15 +110,15 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } if ((dims_input.size() == 4 && dims_grid[3] != 2) || (dims_input.size() == 5 && dims_grid[4] != 3)) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Last dimension of grid input must match the number of " - "spatial dimensions in the input (2 for 2D, 3 for 3D)."); + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Last dimension of grid input must match the number of " + "spatial dimensions in the input (2 for 2D, 3 for 3D)."); } if (dims_input.size() != 4 && dims_input.size() != 5) { return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Only 4-D and 5-D input tensors are supported"); } - if (dims_input.size() == 5 && mode_i_ == 2) { // This is common for CPU and CUDA to not support Cubic mode for 5D input // So it won't break CUDA users who were previously dropping down to CPU version of the op. @@ -152,7 +152,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { CudaT* Y_data = reinterpret_cast(Y->MutableData()); if (dims_input.size() == 4) { - // sample 2d + // sample 2d GridSampleImpl( Stream(context), reinterpret_cast(X->Data()), @@ -180,7 +180,6 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { Y_data); } - return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu index dd29f62e83c08..0e7d947741924 100755 --- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu @@ -262,8 +262,8 @@ SPECIALIZED_IMPL(float, false) // NCHW SPECIALIZED_IMPL(float, true) // NHWC template -__device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) { +__device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x, + int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) { T pixel = 0.0f; auto PixelOffset3D = [bIdx, cIdx, C, D, H, W](int64_t z, int64_t y, int64_t x) -> int64_t { @@ -308,7 +308,6 @@ __global__ void _GridSampleKernel3D( const int64_t H_out, const int64_t W_out, T* __restrict__ output_data) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * D_out * H_out * W_out); // extract batch index, channel index, y index, x index, z index for current thread @@ -327,7 +326,7 @@ __global__ void _GridSampleKernel3D( int tmpHCnt = tmpDCnt + yIdx * W_out; xIdx = (idx - tmpHCnt); - } else { // Layout == LAYOUT_NHWC + } else { // Layout == LAYOUT_NHWC BIdx = idx / (D_out * H_out * W_out * C); int tmpBCnt = BIdx * (D_out * H_out * W_out * C); @@ -343,7 +342,7 @@ __global__ void _GridSampleKernel3D( cIdx = (idx - tmpWCnt); } - int grid_idx = BIdx * D_out * H_out * W_out + zIdx * H_out * W_out + yIdx * W_out + xIdx; + int grid_idx = BIdx * D_out * H_out * W_out + zIdx * H_out * W_out + yIdx * W_out + xIdx; // ONNX spec guideline about ordering of grid indices: // Following computer vision convention, the coordinates in the length-r location vector @@ -381,7 +380,7 @@ __global__ void _GridSampleKernel3D( x_max = float(W_in) - 1.0f; } - float border[] = {z_min, y_min, x_min, z_max, y_max, x_max}; // zmin,ymin,xmin,zmax,ymax,xmax + float border[] = {z_min, y_min, x_min, z_max, y_max, x_max}; // zmin,ymin,xmin,zmax,ymax,xmax if (grid_z_volSpace < z_min || grid_z_volSpace > z_max || grid_y_volSpace < y_min || grid_y_volSpace > y_max || grid_x_volSpace < x_min || grid_x_volSpace > x_max) { // out of bound @@ -434,7 +433,7 @@ __global__ void _GridSampleKernel3D( w_lb_back = w_b * w_l * w_back; w_rb_back = w_b * w_r * w_back; - T lt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C,D_in, H_in, W_in, border); + T lt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border); T rt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border); T lb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border); T rb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border); @@ -460,7 +459,7 @@ __global__ void _GridSampleKernel3D( return; } if (mode == 2) { // cubic - // Not implemented for 3D input. But will not reach here per input validation + // Not implemented for 3D input. But will not reach here per input validation } } @@ -477,7 +476,6 @@ void GridSampleImpl3D( const int64_t H_out, const int64_t W_out, T* output_data) { - int64_t N = 0; int64_t C = 0; int64_t D_in = 0; @@ -507,14 +505,14 @@ void GridSampleImpl3D( } template void GridSampleImpl3D(cudaStream_t stream, const float* input_data, const float* grid_data, - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, - const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out, - float* output_data); + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, + const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out, + float* output_data); template void GridSampleImpl3D(cudaStream_t stream, const float* input_data, const float* grid_data, - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, - const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out, - float* output_data); + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, + const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out, + float* output_data); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 18dc3a6d07446..217233fe89d23 100755 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -13,9 +13,9 @@ std::vector> GetExecutionProviders() { execution_providers.emplace_back(DefaultCpuExecutionProvider()); #ifdef USE_CUDA - execution_providers.emplace_back(DefaultCudaExecutionProvider()); + execution_providers.emplace_back(DefaultCudaExecutionProvider()); #ifdef ENABLE_CUDA_NHWC_OPS - execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); + execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); #endif #endif