From 89d697a06010578389a07778ab735c3d89dd1f9b Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 19 Feb 2024 11:23:10 +0100 Subject: [PATCH 1/9] Implement support for NHWC GridSample in the CUDA EP. --- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 7 ++ onnxruntime/contrib_ops/cuda/grid_sample.cc | 31 ++++--- onnxruntime/contrib_ops/cuda/grid_sample.h | 2 +- .../contrib_ops/cuda/grid_sample_impl.cu | 89 ++++++++++++------- .../contrib_ops/cuda/grid_sample_impl.h | 2 +- .../layout_transformation.cc | 2 + .../providers/cuda/cuda_provider_factory.cc | 3 + .../providers/cuda/shared_inc/cuda_utils.h | 26 ++++++ 8 files changed, 113 insertions(+), 49 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index be8c0dc86c135..5ff4471edbc89 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -203,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif +#ifdef ENABLE_CUDA_NHWC_OPS +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 20, float, GridSample); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -408,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif +#ifdef ENABLE_CUDA_NHWC_OPS + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 4c2999c279e0a..3363b2807822c 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -9,22 +9,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GridSample, \ - kMSDomain, \ - 1, \ + DOMAIN, \ + VERSION, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); + GridSample); -REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) +REGISTER_KERNEL_TYPED(float, 20, LAYOUT_NHWC, kMSInternalNHWCDomain) -template -GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { +template +GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); @@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { } } -template -Status GridSample::ComputeInternal(OpKernelContext* context) const { +template +Status GridSample::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); const auto& dims_input = X->Shape().GetDims(); const Tensor* Grid = context->Input(1); @@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]); ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2"); + using Ch = Channels; + TensorShapeVector dims_output(4); - dims_output[0] = dims_input[0]; - dims_output[1] = dims_input[1]; - dims_output[2] = dims_grid[1]; - dims_output[3] = dims_grid[2]; + dims_output[Ch::N] = dims_input[Ch::N]; + dims_output[Ch::C] = dims_input[Ch::C]; + dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; Tensor* Y = context->Output(0, dims_output); // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { @@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; CudaT* Y_data = reinterpret_cast(Y->MutableData()); - GridSampleImpl( + GridSampleImpl( Stream(context), reinterpret_cast(X->Data()), reinterpret_cast(Grid->Data()), diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h index 08ca58c7cc458..16581bfe77482 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample.h @@ -12,7 +12,7 @@ namespace cuda { using namespace onnxruntime::cuda; -template +template class GridSample final : public CudaKernel { public: explicit GridSample(const OpKernelInfo& info); diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index 8a391eca7e86a..d7efb371ba4ee 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -50,28 +50,32 @@ __device__ T GsReflect(T x, float x_min, float x_max) { return static_cast(fx); } -template +template __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { + int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; + + auto PixelOffset = [bIdx, cIdx, N, C, H, W](int64_t x, int64_t y) -> int64_t { + return Layout == LAYOUT_NCHW ? (bIdx * C * H * W + cIdx * H * W + y * W + x) : (bIdx * H * W * C + y * W * C + x * C + cIdx); + }; + if (padding_mode == 0) { // zeros if (x >= 0 && x < W && y >= 0 && y < H) { - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } - } else if (padding_mode == 1) { //border + } else if (padding_mode == 1) { // border x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x)); y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } else { // Reflection - x = (int64_t) GsReflect(x, border[0], border[2]); - y = (int64_t) GsReflect(y, border[1], border[3]); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + x = (int64_t)GsReflect(x, border[0], border[2]); + y = (int64_t)GsReflect(y, border[1], border[3]); + pixel = input_data[PixelOffset(x, y)]; } return pixel; } -__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) -{ +__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) { float cubic_alpha = -0.75f; x = abs(x); coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha); @@ -93,7 +97,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) { return pixel; } -template +template __global__ void _GridSampleKernel( const T* input_data, const T* grid_data, @@ -110,16 +114,32 @@ __global__ void _GridSampleKernel( { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); // extract batch index, channel index, y index, x index for current thread - int BIdx = idx / (C * H_out * W_out ); - int tmpBCnt = BIdx * (C * H_out * W_out); + int BIdx, yIdx, xIdx, cIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * H_out * W_out); + int tmpBCnt = BIdx * (C * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); - int cIdx = (idx - tmpBCnt) / (H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); + yIdx = (idx - tmpCCnt) / W_out; + int tmpHCnt = tmpCCnt + yIdx * W_out; - int yIdx = (idx - tmpCCnt) / W_out; - int tmpHCnt = tmpCCnt + yIdx * W_out; + xIdx = (idx - tmpHCnt); + } else { + static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); - int xIdx = (idx - tmpHCnt); + BIdx = idx / (H_out * W_out * C); + int tmpBCnt = BIdx * (H_out * W_out * C); + + yIdx = (idx - tmpBCnt) / (W_out * C); + int tmpHCnt = tmpBCnt + yIdx * (W_out * C); + + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; + + cIdx = (idx - tmpWCnt); + } int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; T grid_X = grid_data[grid_idx * 2 + 0]; @@ -175,10 +195,10 @@ __global__ void _GridSampleKernel( w_lb = w_b * w_l; w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; output_data[outIdx] = interpoV; return; @@ -186,7 +206,7 @@ __global__ void _GridSampleKernel( if (mode == 1) { // nearest int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; - output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -195,7 +215,7 @@ __global__ void _GridSampleKernel( T p[4][4] = {}; // [H][W] for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -204,7 +224,7 @@ __global__ void _GridSampleKernel( } } -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, @@ -216,17 +236,20 @@ void GridSampleImpl( const int64_t H_out, const int64_t W_out, T* output_data) { - int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock)); - _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data); + using Ch = Channels; + + int blocksPerGrid = (int)(ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + _GridSampleKernel<<>>( + input_data, grid_data, mode, padding_mode, align_corners, dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], H_out, W_out, output_data); } -#define SPECIALIZED_IMPL(T) \ - template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ - const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); +#define SPECIALIZED_IMPL(T, IsNHWC) \ + template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ + const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); -SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(float, false) // NCHW +SPECIALIZED_IMPL(float, true) // NHWC } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h index 6df86ce161908..62cd66a48fa84 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 4505d4afdf1e0..0c3f8147a1c6e 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a } #if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +// TODO generate list from registered kernels using nhwc domain const std::unordered_set& GetCUDALayoutSensitiveOps() { static std::unordered_set cuda_nhwc_ops = []() { return std::unordered_set{ @@ -41,6 +42,7 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() { "MaxPool", "GlobalAveragePool", "AveragePool", + "GridSample", }; }(); return cuda_nhwc_ops; diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 103c79c93b2ca..f2391d9196d67 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -219,6 +219,9 @@ struct CUDA_Provider : Provider { info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; info.prefer_nhwc = params->prefer_nhwc; + // HACK + info.prefer_nhwc = true; + //info.prefer_nhwc = false; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index fa987866c002f..94eb9fc4e2967 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -168,5 +168,31 @@ struct NumericLimits { } }; +// TODO Where to put this? good places might be +// core/framework/tensor_shape.h +// core/util/matrix_layout.h + +constexpr bool LAYOUT_NCHW = false; +constexpr bool LAYOUT_NHWC = true; + +template +struct Channels; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t H = 1; + static constexpr size_t W = 2; + static constexpr size_t C = 3; +}; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t C = 1; + static constexpr size_t H = 2; + static constexpr size_t W = 3; +}; + } // namespace cuda } // namespace onnxruntime From e88f1b2fb65c132d92dd1b0ebe184fe8dfedca22 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 19 Feb 2024 15:32:37 +0100 Subject: [PATCH 2/9] Fix issue #10607, grid_sample CUDA kernel clamping is incorrect for border padding mode with align = 1. --- onnxruntime/contrib_ops/cuda/grid_sample_impl.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index d7efb371ba4ee..a8a9bf7cb5e45 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -167,8 +167,9 @@ __global__ void _GridSampleKernel( if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound if (padding_mode == 1) { // border - grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + // Clamping must not be done here, see #10607 + //grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + //grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); } else if (padding_mode == 2) { // reflection grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); From dac6e760e5a4dcf1c21033be6ad09f24621a2ae8 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 19 Feb 2024 15:34:41 +0100 Subject: [PATCH 3/9] Update grid_sample_test.cc to run on all execution providers & change GridSample NHWC version to 16. --- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 +- onnxruntime/contrib_ops/cuda/grid_sample.cc | 8 +- .../providers/cuda/cuda_execution_provider.cc | 2 + .../providers/cpu/tensor/grid_sample_test.cc | 154 +++++++----------- .../cpu/tensor/grid_sample_test_gen.py | 2 +- 5 files changed, 69 insertions(+), 101 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 5ff4471edbc89..57e951d3a68ff 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -204,7 +204,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 #endif #ifdef ENABLE_CUDA_NHWC_OPS -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 20, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample); #endif template <> @@ -413,7 +413,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { #endif #ifdef ENABLE_CUDA_NHWC_OPS - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 3363b2807822c..5df799701c2cd 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -19,10 +19,10 @@ namespace cuda { (*KernelDefBuilder::Create()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); + onnxruntime::contrib::cuda::GridSample); REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) -REGISTER_KERNEL_TYPED(float, 20, LAYOUT_NHWC, kMSInternalNHWCDomain) +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain) template GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { @@ -92,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } } // namespace cuda } // namespace contrib + +namespace cuda { + REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +} // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 48a952e6dd98f..449f9577809ab 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1256,6 +1256,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -2143,6 +2144,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 0f097622abff0..9f7cdbed5035f 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -6,6 +6,15 @@ namespace onnxruntime { namespace test { + +static std::unordered_set GetExcludedExecutionProviders(int version) { + std::unordered_set excluded_providers; + if (version >= 20) { + excluded_providers.insert(kCudaExecutionProvider); + } + return excluded_providers; +} + // DO NOT edit following tests. They are generated by: // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { @@ -25,8 +34,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { @@ -46,8 +54,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { @@ -67,8 +74,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { @@ -88,8 +94,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { @@ -109,8 +114,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { @@ -130,8 +134,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { @@ -151,8 +154,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { @@ -172,8 +174,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { @@ -193,8 +194,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { @@ -214,8 +214,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { @@ -235,8 +234,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { @@ -256,8 +254,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { @@ -277,8 +274,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { @@ -298,8 +294,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { @@ -319,8 +314,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { @@ -340,8 +334,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { @@ -361,8 +354,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { @@ -382,8 +374,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { @@ -403,8 +394,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { @@ -424,8 +414,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { @@ -445,8 +434,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { @@ -466,8 +454,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { @@ -487,8 +474,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { @@ -508,8 +494,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { @@ -529,8 +514,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { @@ -550,8 +534,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { @@ -571,8 +554,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { @@ -592,8 +574,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { @@ -613,8 +594,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { @@ -634,8 +614,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { @@ -655,8 +634,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { @@ -676,8 +654,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { @@ -697,8 +674,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { @@ -718,8 +694,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { @@ -739,8 +714,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { @@ -760,8 +734,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { @@ -781,8 +754,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { @@ -802,8 +774,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { @@ -823,8 +794,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { @@ -844,8 +814,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { @@ -865,8 +834,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { @@ -886,8 +854,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { @@ -907,8 +874,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { @@ -928,8 +894,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { @@ -949,8 +914,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { @@ -970,8 +934,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { @@ -991,8 +954,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { @@ -1012,8 +974,8 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index e4d58e79243ef..1cac440fa7605 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -76,6 +76,6 @@ print('test.AddAttribute("padding_mode", padding_mode);') print('test.AddAttribute("align_corners", align_corners);') print('test.AddOutput("Y", Y_shape, Y_data);') - print("test.Run();") + print(f'test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders({opset_version}));') print("}") print("\n") From 16bc36558a735ab17196f2441bddcd5fc547d2b5 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 19 Feb 2024 19:15:45 +0100 Subject: [PATCH 4/9] Run lintrunner & remove nhwc hack. --- onnxruntime/contrib_ops/cuda/grid_sample.cc | 8 ++++---- onnxruntime/core/providers/cuda/cuda_provider_factory.cc | 3 --- onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h | 2 +- .../test/providers/cpu/tensor/grid_sample_test_gen.py | 4 +++- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 5df799701c2cd..2500de39d3536 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -9,11 +9,11 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ +#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GridSample, \ DOMAIN, \ - VERSION, \ + VERSION, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -68,7 +68,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { dims_output[Ch::N] = dims_input[Ch::N]; dims_output[Ch::C] = dims_input[Ch::C]; dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; - dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; Tensor* Y = context->Output(0, dims_output); // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { @@ -94,6 +94,6 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } // namespace contrib namespace cuda { - REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index f2391d9196d67..103c79c93b2ca 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -219,9 +219,6 @@ struct CUDA_Provider : Provider { info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; info.prefer_nhwc = params->prefer_nhwc; - // HACK - info.prefer_nhwc = true; - //info.prefer_nhwc = false; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index 94eb9fc4e2967..54c024793ff0b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -168,7 +168,7 @@ struct NumericLimits { } }; -// TODO Where to put this? good places might be +// TODO Where to put this? good places might be // core/framework/tensor_shape.h // core/util/matrix_layout.h diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index 1cac440fa7605..f83ecbec14f0a 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -76,6 +76,8 @@ print('test.AddAttribute("padding_mode", padding_mode);') print('test.AddAttribute("align_corners", align_corners);') print('test.AddOutput("Y", Y_shape, Y_data);') - print(f'test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders({opset_version}));') + print( + f'test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders({opset_version}));' + ) print("}") print("\n") From 7431b20649171159faf20b3dfc5b8f4763d5e3ca Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 21 Feb 2024 16:50:08 +0100 Subject: [PATCH 5/9] Add functionality for inclusive list to run grid_sample tests --- .../providers/cpu/tensor/grid_sample_test.cc | 124 ++++++++++-------- .../cpu/tensor/grid_sample_test_gen.py | 4 +- onnxruntime/test/util/default_providers.cc | 14 ++ .../test/util/include/default_providers.h | 3 + 4 files changed, 89 insertions(+), 56 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 9f7cdbed5035f..1c82ec1fd7ee3 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -7,12 +7,30 @@ namespace onnxruntime { namespace test { -static std::unordered_set GetExcludedExecutionProviders(int version) { - std::unordered_set excluded_providers; - if (version >= 20) { - excluded_providers.insert(kCudaExecutionProvider); +std::vector> GetExecutionProviders(int opset_version) { + std::vector> execution_providers; + + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + if (opset_version < 20) { + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#ifdef ENABLE_CUDA_NHWC_OPS + execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); +#endif } - return excluded_providers; + +#endif + return execution_providers; +} + +template +void RunTests(T& test, std::vector>& execution_providers) { + for (size_t idx = 0; idx < execution_providers.size(); ++idx) { + std::vector> execution_provider; + execution_provider.push_back(std::move(execution_providers[idx])); + test.ConfigEps(std::move(execution_provider)).RunWithConfig(); + } + execution_providers.clear(); } // DO NOT edit following tests. They are generated by: @@ -34,7 +52,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { @@ -54,7 +72,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { @@ -74,7 +92,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { @@ -94,7 +112,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { @@ -114,7 +132,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { @@ -134,7 +152,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { @@ -154,7 +172,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { @@ -174,7 +192,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { @@ -194,7 +212,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { @@ -214,7 +232,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { @@ -234,7 +252,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { @@ -254,7 +272,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { @@ -274,7 +292,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { @@ -294,7 +312,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { @@ -314,7 +332,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { @@ -334,7 +352,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { @@ -354,7 +372,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { @@ -374,7 +392,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(16)); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { @@ -394,7 +412,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { @@ -414,7 +432,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { @@ -434,7 +452,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { @@ -454,7 +472,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { @@ -474,7 +492,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { @@ -494,7 +512,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { @@ -514,7 +532,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { @@ -534,7 +552,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { @@ -554,7 +572,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { @@ -574,7 +592,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { @@ -594,7 +612,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { @@ -614,7 +632,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { @@ -634,7 +652,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { @@ -654,7 +672,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { @@ -674,7 +692,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { @@ -694,7 +712,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { @@ -714,7 +732,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { @@ -734,7 +752,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { @@ -754,7 +772,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { @@ -774,7 +792,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { @@ -794,7 +812,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { @@ -814,7 +832,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { @@ -834,7 +852,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { @@ -854,7 +872,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { @@ -874,7 +892,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { @@ -894,7 +912,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { @@ -914,7 +932,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { @@ -934,7 +952,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { @@ -954,7 +972,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { @@ -974,7 +992,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders(20)); + RunTests(test, GetExecutionProviders(20)); } } // namespace test diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index f83ecbec14f0a..c60e55617774f 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -76,8 +76,6 @@ print('test.AddAttribute("padding_mode", padding_mode);') print('test.AddAttribute("align_corners", align_corners);') print('test.AddOutput("Y", Y_shape, Y_data);') - print( - f'test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders({opset_version}));' - ) + print(f"RunTests(test, GetExecutionProviders({opset_version}));") print("}") print("\n") diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 40b40136af1af..6e2c49d5baac6 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -8,6 +8,9 @@ #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif +#if defined(ENABLE_CUDA_NHWC_OPS) +#include +#endif #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/session_options.h" @@ -118,6 +121,17 @@ std::unique_ptr DefaultCudaExecutionProvider() { return nullptr; } +std::unique_ptr DefaultCudaNHWCExecutionProvider() { +#ifdef USE_CUDA + OrtCUDAProviderOptionsV2 provider_options{}; + provider_options.do_copy_in_default_stream = true; + provider_options.prefer_nhwc = true; + if (auto factory = CudaProviderFactoryCreator::Create(&provider_options)) + return factory->CreateProvider(); +#endif + return nullptr; +} + std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options) { #ifdef USE_CUDA if (auto factory = CudaProviderFactoryCreator::Create(provider_options)) diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 9f78e0a0d4eb2..738fc66d775c6 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -35,6 +35,9 @@ namespace test { // unique_ptr providers with default values for session registration std::unique_ptr DefaultCpuExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultCudaExecutionProvider(); +#ifdef ENABLE_CUDA_NHWC_OPS +std::unique_ptr DefaultCudaNHWCExecutionProvider(); +#endif std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options); std::unique_ptr DefaultDnnlExecutionProvider(); std::unique_ptr DnnlExecutionProviderWithOptions(const OrtDnnlProviderOptions* provider_options); From 79b8827ac76c28c23ee145f49ce1a0e474ae817e Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 21 Feb 2024 23:49:41 +0100 Subject: [PATCH 6/9] Fix lvalue issue and protect DefaultCudaNHWCExecutionProvider() by NHWC ifdef --- onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc | 6 ++---- onnxruntime/test/util/default_providers.cc | 4 +++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 1c82ec1fd7ee3..b7761835927c6 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -24,11 +24,9 @@ std::vector> GetExecutionProviders(int opset } template -void RunTests(T& test, std::vector>& execution_providers) { +void RunTests(T& test, std::vector>&& execution_providers) { for (size_t idx = 0; idx < execution_providers.size(); ++idx) { - std::vector> execution_provider; - execution_provider.push_back(std::move(execution_providers[idx])); - test.ConfigEps(std::move(execution_provider)).RunWithConfig(); + test.ConfigEp(std::move(execution_providers[idx])).RunWithConfig(); } execution_providers.clear(); } diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 6e2c49d5baac6..b404c12db3582 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -121,8 +121,9 @@ std::unique_ptr DefaultCudaExecutionProvider() { return nullptr; } +#ifdef ENABLE_CUDA_NHWC_OPS std::unique_ptr DefaultCudaNHWCExecutionProvider() { -#ifdef USE_CUDA +#if defined(USE_CUDA) OrtCUDAProviderOptionsV2 provider_options{}; provider_options.do_copy_in_default_stream = true; provider_options.prefer_nhwc = true; @@ -131,6 +132,7 @@ std::unique_ptr DefaultCudaNHWCExecutionProvider() { #endif return nullptr; } +#endif std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options) { #ifdef USE_CUDA From 2b5a5db848f986b665755fc23c4bfe18ed116a25 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Thu, 22 Feb 2024 19:06:59 +0100 Subject: [PATCH 7/9] Fix lintrunner issues & unused variable without CUDA build. --- .../contrib_ops/cuda/grid_sample_impl.cu | 21 ++++++++++++------- .../layout_transformation.cc | 2 +- .../providers/cpu/tensor/grid_sample_test.cc | 2 ++ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index a8a9bf7cb5e45..fd3603424ffbd 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -56,7 +56,9 @@ __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_ T pixel = 0.0f; auto PixelOffset = [bIdx, cIdx, N, C, H, W](int64_t x, int64_t y) -> int64_t { - return Layout == LAYOUT_NCHW ? (bIdx * C * H * W + cIdx * H * W + y * W + x) : (bIdx * H * W * C + y * W * C + x * C + cIdx); + return Layout == LAYOUT_NCHW + ? (bIdx * C * H * W + cIdx * H * W + y * W + x) + : (bIdx * H * W * C + y * W * C + x * C + cIdx); }; if (padding_mode == 0) { // zeros @@ -168,8 +170,8 @@ __global__ void _GridSampleKernel( grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound if (padding_mode == 1) { // border // Clamping must not be done here, see #10607 - //grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - //grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); } else if (padding_mode == 2) { // reflection grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); @@ -207,7 +209,8 @@ __global__ void _GridSampleKernel( if (mode == 1) { // nearest int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; - output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + output_data[outIdx] = + PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -216,7 +219,8 @@ __global__ void _GridSampleKernel( T p[4][4] = {}; // [H][W] for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + p[h][w] = + PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -239,9 +243,12 @@ void GridSampleImpl( T* output_data) { using Ch = Channels; - int blocksPerGrid = (int)(ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + int blocksPerGrid = static_cast + ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], H_out, W_out, output_data); + input_data, grid_data, mode, padding_mode, align_corners, + dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], + H_out, W_out, output_data); } #define SPECIALIZED_IMPL(T, IsNHWC) \ diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 0c3f8147a1c6e..a8717b99a8750 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -31,7 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a } #if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS -// TODO generate list from registered kernels using nhwc domain +// TODO(mtavenrath) generate list from registered kernels using nhwc domain const std::unordered_set& GetCUDALayoutSensitiveOps() { static std::unordered_set cuda_nhwc_ops = []() { return std::unordered_set{ diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index b7761835927c6..5c89d6ea7bd75 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -8,6 +8,8 @@ namespace onnxruntime { namespace test { std::vector> GetExecutionProviders(int opset_version) { + ORT_UNUSED_PARAMETER(opset_version); + std::vector> execution_providers; execution_providers.emplace_back(DefaultCpuExecutionProvider()); From 58a81b9d636980d813f900330320069c5cff0a72 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Thu, 22 Feb 2024 20:15:40 +0100 Subject: [PATCH 8/9] Add missing ( --- onnxruntime/contrib_ops/cuda/grid_sample_impl.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index fd3603424ffbd..8553f271b1235 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -243,7 +243,7 @@ void GridSampleImpl( T* output_data) { using Ch = Channels; - int blocksPerGrid = static_cast + int blocksPerGrid = static_cast( ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); _GridSampleKernel<<>>( input_data, grid_data, mode, padding_mode, align_corners, From c91413c87337e1b0045b24e6dfa808fa85f4d2c3 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Thu, 22 Feb 2024 22:46:57 +0100 Subject: [PATCH 9/9] Updating docs & fix lambda capture --- docs/OperatorKernels.md | 1 + onnxruntime/contrib_ops/cuda/grid_sample_impl.cu | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8ff2135c6b1f6..ef8fccc37902a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -617,6 +617,7 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index 8553f271b1235..b23da635bc83d 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -55,7 +55,7 @@ __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_ int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; - auto PixelOffset = [bIdx, cIdx, N, C, H, W](int64_t x, int64_t y) -> int64_t { + auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { return Layout == LAYOUT_NCHW ? (bIdx * C * H * W + cIdx * H * W + y * W + x) : (bIdx * H * W * C + y * W * C + x * C + cIdx);