From a578c0a1567c8f4cd16add2399d39f98d9332109 Mon Sep 17 00:00:00 2001
From: Tianlei Wu <tlwu@microsoft.com>
Date: Tue, 17 Mar 2026 14:25:09 -0700
Subject: [PATCH] CUDA MaxPool-22

---
 .../providers/cuda/cuda_execution_provider.cc | 30 ++++++++++++-------
 .../core/providers/cuda/cuda_nhwc_kernels.cc  | 24 ++++++++++-----
 onnxruntime/core/providers/cuda/nn/pool.cc    | 28 +++++++++++------
 .../test/providers/cpu/nn/pool_op_test.cc     | 30 +++++++++++++++++++
 .../test/providers/cuda/nhwc/pool_test.cc     |  4 ++-
 5 files changed, 88 insertions(+), 28 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 60ac16018f539..ffad7cf520619 100755
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1037,11 +1037,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
 
 // OpSet 12
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, MaxPool);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, MaxPool);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, float, MaxPool);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, double, MaxPool);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, MLFloat16, MaxPool);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, int8_t, MaxPool);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, uint8_t, MaxPool);
 
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Pow);
 
@@ -1579,6 +1579,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, AveragePool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, AveragePool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, AveragePool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, int8_t, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, uint8_t, MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Conv);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv);
@@ -2123,11 +2128,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
 
       // OpSet 12
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, MaxPool)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, float, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, double, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, int8_t, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 21, uint8_t, MaxPool)>,
 
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Pow)>,
 
@@ -2663,6 +2668,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, AveragePool)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, AveragePool)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, AveragePool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, int8_t, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, uint8_t, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Conv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv)>,
diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
index 8239a8ac252e6..78a02926a950a 100755
--- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
+++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -65,10 +65,14 @@ class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, AveragePool);
 class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, MLFloat16, AveragePool);
 class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, float, MaxPool);
 class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, MLFloat16, MaxPool);
-class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, float, MaxPool);
-class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, MLFloat16, MaxPool);
-class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, int8_t, MaxPool);
-class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, uint8_t, MaxPool);
+class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, float, MaxPool);
+class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, MLFloat16, MaxPool);
+class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, int8_t, MaxPool);
+class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, uint8_t, MaxPool);
+class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, MaxPool);
+class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, MLFloat16, MaxPool);
+class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, int8_t, MaxPool);
+class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, uint8_t, MaxPool);
 class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, float, BatchNormalization);
 class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, double, BatchNormalization);
 class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, MLFloat16, BatchNormalization);
@@ -130,10 +134,14 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(22, MLFloat16, AveragePool)>,
       BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, float, MaxPool)>,
       BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, MLFloat16, MaxPool)>,
-      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, float, MaxPool)>,
-      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, MLFloat16, MaxPool)>,
-      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, int8_t, MaxPool)>,
-      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, uint8_t, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, float, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, int8_t, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(12, 21, uint8_t, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(22, MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(22, int8_t, MaxPool)>,
+      BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(22, uint8_t, MaxPool)>,
       BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, float, BatchNormalization)>,
       BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, double, BatchNormalization)>,
       BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, MLFloat16, BatchNormalization)>,
diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc
index f5fb851e5a061..3878c04dfb694 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.cc
+++ b/onnxruntime/core/providers/cuda/nn/pool.cc
@@ -72,11 +72,17 @@ POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10, kO
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kOnnxDomain, false)
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 11, 11, kOnnxDomain, false)
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kOnnxDomain, false)
-POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kOnnxDomain, false)
-POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 12, kOnnxDomain, false)
-POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kOnnxDomain, false)
-POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kOnnxDomain, false)
-POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kOnnxDomain, false)
+// MaxPool opsets 12-22 share the same CUDA implementation for the currently supported types.
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, 21, kOnnxDomain, false)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 12, 21, kOnnxDomain, false)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, 21, kOnnxDomain, false)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, 21, kOnnxDomain, false)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, 21, kOnnxDomain, false)
+POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 22, kOnnxDomain, false)
+POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 22, kOnnxDomain, false)
+POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 22, kOnnxDomain, false)
+POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 22, kOnnxDomain, false)
+POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 22, kOnnxDomain, false)
 
 POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kOnnxDomain, false)
 POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1, kOnnxDomain, false)
@@ -92,10 +98,14 @@ POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10, kMSInt
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10, kMSInternalNHWCDomain, true)
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true)
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true)
-POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
-POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
-POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
-POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, 21, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, 21, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, 21, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, 21, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 22, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 22, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 22, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 22, kMSInternalNHWCDomain, true)
 
 POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kMSInternalNHWCDomain, true)
 POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kMSInternalNHWCDomain, true)
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index c7a3526d9f030..83910fa4de685 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -195,6 +195,36 @@ TEST(PoolTest, MaxPool_8_With_Index) {
   MaxPool_8_WithIndexTest(true, 1 /*storage_order*/);   // col major
 }
 
+TEST(PoolTest, MaxPool_22_With_Index_CUDA) {
+  auto cuda_ep = DefaultCudaExecutionProvider();
+  if (!cuda_ep) {
+    return;
+  }
+
+  OpTester test("MaxPool", 22);
+  test.AddAttribute("strides", std::vector<int64_t>{2, 2});
+  test.AddAttribute("pads", vector<int64_t>{0, 0, 0, 0});
+  test.AddAttribute("kernel_shape", vector<int64_t>{2, 2});
+
+  std::vector<float> x_vals = {
+      1.0f, 2.0f, 3.0f, 4.0f,
+      5.0f, 6.0f, 7.0f, 8.0f,
+      9.0f, 10.0f, 11.0f, 12.0f,
+      13.0f, 14.0f, 15.0f, 16.0f};
+  std::vector<int64_t> x_dims = {1, 1, 4, 4};
+  std::vector<int64_t> expected_dims = {1, 1, 2, 2};
+  std::vector<float> expected_vals = {6.0f, 8.0f, 14.0f, 16.0f};
+  std::vector<int64_t> expected_indices = {5, 7, 13, 15};
+
+  test.AddInput<float>("X", x_dims, x_vals);
+  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  test.AddOutput<int64_t>("Indices", expected_dims, expected_indices);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(std::move(cuda_ep));
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
 TEST(PoolTest, MaxPool1D_case1) {
   OpTester test("MaxPool");
 
diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc
index 426170b9588f1..877da668345fd 100644
--- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc
+++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc
@@ -10,6 +10,7 @@ namespace test {
 template <typename T>
 struct PoolOp {
   std::string pooling_type;
+  int opset = 14;
   std::vector<int64_t> input_dims;
   std::vector<int64_t> kernel_shape;
   int64_t channels;
@@ -20,7 +21,7 @@ struct PoolOp {
 
   std::unique_ptr<CompareOpTester> get_test() {
     RandomValueGenerator random{};
-    auto test = std::make_unique<CompareOpTester>(pooling_type.c_str(), 14);
+    auto test = std::make_unique<CompareOpTester>(pooling_type.c_str(), opset);
 
     std::vector<T> input_data = random.Uniform<T>(input_dims, 0.0f, 0.3f);
     test->AddInput<T>("X", input_dims, input_data);
@@ -53,6 +54,7 @@ TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwc) {
 TYPED_TEST(CudaNhwcTypedTest, MaxPoolNhwc) {
   auto op = PoolOp<TypeParam>{};
   op.pooling_type = "MaxPool";
+  op.opset = 22;
   op.input_dims = {1, 16, 64, 64};
   op.kernel_shape = {3, 3};
   op.channels = 16;