From 3dc473e765d880b08a0cea50272ff7de9aea299f Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 23 Feb 2026 00:18:38 +0800 Subject: [PATCH] Add pad op version 19 to 23 support for CUDA --- .../providers/cuda/cuda_execution_provider.cc | 40 +++++-- onnxruntime/core/providers/cuda/tensor/pad.cc | 50 ++++++++- .../core/providers/cuda/tensor/pad_impl.cu | 102 ++++++++++++++++-- .../core/providers/cuda/tensor/pad_impl.h | 6 ++ .../test/providers/cpu/tensor/pad_test.cc | 3 - 5 files changed, 181 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index bf6fcc7ccf0a8..5ea2a15c188c8 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1441,10 +1441,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Pad); 
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize); @@ -1452,6 +1452,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize); // Opset 19 +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool); @@ -1572,6 +1576,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, bool, Pad); // Opset 22. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, AveragePool); @@ -1621,6 +1629,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); #endif +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Transpose); @@ -2519,10 +2531,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2530,6 
+2542,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 19-20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2599,6 +2615,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 21 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // TODO(fajin): support other quantized types BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2671,6 +2691,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 961ce9ffd721b..215d4c9adb26a 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -40,10 +40,46 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 18, 18, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 19, 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ 
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 21, 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Pad, \ kOnnxDomain, \ - 18, \ + 23, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -154,6 +190,12 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { effective_input_extents.push_back(extent); } + TArray slice_starts(dimension_count); + for (size_t i = 0; i < dimension_count; ++i) { + slice_starts[i] = (*p_slices)[i]; + } + TArray effective_input_dims(effective_input_extents); + TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); @@ -248,6 +290,10 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { output_dims[width_dim], lower_pads[height_dim], lower_pads[width_dim], + slice_starts[height_dim], + slice_starts[width_dim], + effective_input_dims[height_dim], + effective_input_dims[width_dim], value, static_cast(mode_), reinterpret_cast::MappedType*>(input_tensor.Data()), @@ -269,6 +315,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { input_dims, input_strides, lower_pads, + slice_starts, + effective_input_dims, value, static_cast(mode_), reinterpret_cast::MappedType*>(input_tensor.Data()), diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu index 6f530e800fdf2..e5fd700f12c78 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu @@ -7,11 +7,12 @@ namespace onnxruntime { namespace cuda { -// PadMode enum from core/providers/cpu/tensor/pad.h, cannot use that header because of nvcc/onnxruntime incompatibility +// PadMode enum from core/providers/cpu/tensor/padbase.h, 
cannot use that header because of nvcc/onnxruntime incompatibility enum class PadMode : int { Constant = 0, Reflect, - Edge + Edge, + Wrap }; template <typename T, int pad_mode> @@ -20,6 +21,8 @@ __global__ void _PadKernel( const TArray<int64_t> input_dims, const TArray<int64_t> input_strides, const TArray<int64_t> lower_pads, + const TArray<int64_t> slice_starts, + const TArray<int64_t> effective_input_dims, const T pad_value, const T* input_data, const TArray<fast_divmod> fdm_output_strides, @@ -45,6 +48,20 @@ __global__ void _PadKernel( case PadMode::Reflect: in_coord = lower_pads[dim] - out_coord; break; + case PadMode::Wrap: { + int64_t eff_len = effective_input_dims[dim]; + if (eff_len == 0) { + // Should not happen if output size > 0, but safe fallback + in_coord = 0; + } else { + // Match CPU: effective_index = (eff_len - lower_pads[dim] + out_coord) % eff_len + // then in_coord = -slice_starts[dim] + effective_index (map to input via effective region) + int64_t effective_index = (eff_len - lower_pads[dim] + static_cast<int64_t>(out_coord)) % eff_len; + if (effective_index < 0) effective_index += eff_len; + in_coord = static_cast<int64_t>(-slice_starts[dim] + effective_index); + } + break; + } } } else if (out_coord >= lower_pads[dim] + input_dims[dim]) { switch ((PadMode)pad_mode) { @@ -57,6 +74,18 @@ __global__ void _PadKernel( case PadMode::Reflect: in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim])); break; + case PadMode::Wrap: { + int64_t eff_len = effective_input_dims[dim]; + if (eff_len == 0) { + in_coord = 0; + } else { + // Match CPU: effective_index = (eff_len - lower_pads[dim] + out_coord) % eff_len + int64_t effective_index = (eff_len - lower_pads[dim] + static_cast<int64_t>(out_coord)) % eff_len; + if (effective_index < 0) effective_index += eff_len; + in_coord = static_cast<int64_t>(-slice_starts[dim] + effective_index); + } + break; + } } } else { in_coord = out_coord - lower_pads[dim]; @@ -76,6 +105,10 @@ __global__ void _PadNCHWInputWithPaddingAlongHAndWKernel( const int64_t output_width, const int64_t pad_height_start, const 
int64_t pad_width_start, + const int64_t slice_height_start, + const int64_t slice_width_start, + const int64_t effective_input_height, + const int64_t effective_input_width, const T pad_value, const T* input_data, T* output_data, @@ -126,6 +159,32 @@ __global__ void _PadNCHWInputWithPaddingAlongHAndWKernel( input_width + current_input_width]; break; + + case PadMode::Wrap: { + int64_t h_in, w_in; + + if (effective_input_height == 0) { + h_in = 0; + } else { + // Match CPU: effective_index = (eff_len - lower_pads + out_coord) % eff_len + int64_t effective_index_h = (effective_input_height - pad_height_start + current_output_height) % effective_input_height; + if (effective_index_h < 0) effective_index_h += effective_input_height; + h_in = -slice_height_start + effective_index_h; + } + + if (effective_input_width == 0) { + w_in = 0; + } else { + // Match CPU: effective_index = (eff_len - lower_pads + out_coord) % eff_len + int64_t effective_index_w = (effective_input_width - pad_width_start + current_output_width) % effective_input_width; + if (effective_index_w < 0) effective_index_w += effective_input_width; + w_in = -slice_width_start + effective_index_w; + } + + output_data[id] = input_data[(nc_index * input_height + static_cast(h_in)) * input_width + + static_cast(w_in)]; + break; + } } } @@ -136,6 +195,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& slice_starts, + const TArray& effective_input_dims, const T pad_value, const int pad_mode, const T* input_data, @@ -149,17 +210,22 @@ void PadImpl( switch (pad_mode) { case 0: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, slice_starts, effective_input_dims, pad_value, input_data, fdm_output_strides, output_data, N); break; case 1: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, slice_starts, 
effective_input_dims, pad_value, input_data, fdm_output_strides, output_data, N); break; case 2: _PadKernel<<>>( - shape_rank, input_dims, input_strides, lower_pads, + shape_rank, input_dims, input_strides, lower_pads, slice_starts, effective_input_dims, + pad_value, input_data, fdm_output_strides, output_data, N); + break; + case 3: + _PadKernel<<>>( + shape_rank, input_dims, input_strides, lower_pads, slice_starts, effective_input_dims, pad_value, input_data, fdm_output_strides, output_data, N); break; } @@ -176,6 +242,10 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( const int64_t output_width, const int64_t pad_height_start, const int64_t pad_width_start, + const int64_t slice_height_start, + const int64_t slice_width_start, + const int64_t effective_input_height, + const int64_t effective_input_width, const T pad_value, const int pad_mode, const T* input_data, @@ -189,19 +259,29 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( case 0: _PadNCHWInputWithPaddingAlongHAndWKernel<<>>( n, c, input_height, output_height, input_width, output_width, - pad_height_start, pad_width_start, + pad_height_start, pad_width_start, slice_height_start, slice_width_start, + effective_input_height, effective_input_width, pad_value, input_data, output_data, N); break; case 1: _PadNCHWInputWithPaddingAlongHAndWKernel<<>>( n, c, input_height, output_height, input_width, output_width, - pad_height_start, pad_width_start, + pad_height_start, pad_width_start, slice_height_start, slice_width_start, + effective_input_height, effective_input_width, pad_value, input_data, output_data, N); break; case 2: _PadNCHWInputWithPaddingAlongHAndWKernel<<>>( n, c, input_height, output_height, input_width, output_width, - pad_height_start, pad_width_start, + pad_height_start, pad_width_start, slice_height_start, slice_width_start, + effective_input_height, effective_input_width, + pad_value, input_data, output_data, N); + break; + case 3: + _PadNCHWInputWithPaddingAlongHAndWKernel<<>>( + n, c, 
input_height, output_height, input_width, output_width, + pad_height_start, pad_width_start, slice_height_start, slice_width_start, + effective_input_height, effective_input_width, pad_value, input_data, output_data, N); break; } @@ -211,6 +291,8 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( template void PadImpl(cudaStream_t stream, const size_t shape_rank, \ const TArray& input_dims, const TArray& input_strides, \ const TArray& lower_pads, \ + const TArray& slice_starts, \ + const TArray& effective_input_dims, \ const T pad_value, \ const int pad_mode, \ const T* input_data, \ @@ -222,6 +304,10 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( const int64_t input_width, const int64_t output_width, \ const int64_t pad_height_start, \ const int64_t pad_width_start, \ + const int64_t slice_height_start, \ + const int64_t slice_width_start, \ + const int64_t effective_input_height, \ + const int64_t effective_input_width, \ const T pad_value, \ const int pad_mode, \ const T* input_data, T* output_data, \ diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.h b/onnxruntime/core/providers/cuda/tensor/pad_impl.h index dc700ea2304e9..9cd6df8368f8e 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.h @@ -19,6 +19,10 @@ void PadNCHWInputWithPaddingAlongHAndWImpl( const int64_t output_width, const int64_t pad_height_start, const int64_t pad_width_start, + const int64_t slice_height_start, + const int64_t slice_width_start, + const int64_t effective_input_height, + const int64_t effective_input_width, const T pad_value, const int pad_mode, const T* input_data, @@ -32,6 +36,8 @@ void PadImpl( const TArray& input_dims, const TArray& input_strides, const TArray& lower_pads, + const TArray& slice_starts, + const TArray& effective_input_dims, const T pad_value, const int pad_mode, const T* input_data, diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc 
b/onnxruntime/test/providers/cpu/tensor/pad_test.cc index a18ff6dabbae0..87cc566639a0d 100644 --- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc @@ -1391,9 +1391,6 @@ TEST(PadOpTest, Pad_Wrap_NegativeFront_PositiveBack) { // Post-slice core: [4]; wrap 3 -> [4, 4, 4, 4] const std::vector<int32_t> expected_data = {4, 4, 4, 4}; - // CUDA registers only up to 18 and does not impl wrap mode - // so we force version to 19 to automatically exclude EPs that do not - // implement wrap mode similar to the above tests. OpTester test("Pad", 19); test.AddInput<int32_t>("data", input_shape, input_data); test.AddInput<int64_t>("pads", {static_cast<int64_t>(pads.size())}, pads, true);