diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 39c9145a40912..e6236fb5d1ea1 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -843,7 +843,12 @@ Do not modify directly.*
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|25+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||24|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||23|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
+|||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index b87cf8cbc16c1..de92dfddac6d9 100755
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1441,10 +1441,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, bool, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize);
@@ -1452,6 +1452,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize);
// Opset 19
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool);
@@ -1573,6 +1577,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Transpose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, bool, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, ConstantOfShape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, Identity);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, If);
@@ -1639,6 +1647,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast);
#endif
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Transpose);
@@ -1663,10 +1675,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Squeeze);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Unsqueeze);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, float, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, double, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, MLFloat16, Pad);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, bool, Pad);
// Opset 25.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Unsqueeze);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, ConstantOfShape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, If);
@@ -2560,10 +2580,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, bool, Pad)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2571,6 +2591,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 19-20
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Pad)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2641,6 +2665,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 21
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, bool, Pad)>,
// TODO(fajin): support other quantized types
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2730,6 +2758,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 23
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, bool, Pad)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2782,10 +2814,18 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, bool, Pad)>,
// Opset 25
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Pad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Pad)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc
index 9b23209953081..3dd50c1c03cbf 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad.cc
+++ b/onnxruntime/core/providers/cuda/tensor/pad.cc
@@ -40,10 +40,70 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 2) \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),               \
      Pad<T>);                                                                  \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Pad, \
+ kOnnxDomain, \
+ 18, 18, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),              \
+      Pad<T>);                                                                 \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Pad, \
+ kOnnxDomain, \
+ 19, 20, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),              \
+      Pad<T>);                                                                 \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Pad, \
+ kOnnxDomain, \
+ 21, 22, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),              \
+      Pad<T>);                                                                 \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Pad, \
+ kOnnxDomain, \
+ 23, 23, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),              \
+      Pad<T>);                                                                 \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Pad, \
+ kOnnxDomain, \
+ 24, 24, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),              \
+      Pad<T>);                                                                 \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Pad, \
kOnnxDomain, \
- 18, \
+ 25, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
@@ -154,6 +214,11 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const {
effective_input_extents.push_back(extent);
}
+  TArray<int64_t> input_offsets(dimension_count);
+ for (int32_t i = 0; i < dimension_count; ++i) {
+ input_offsets[i] = -(*p_slices)[i];
+ }
+
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);
@@ -236,7 +301,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const {
return Status::OK();
}
- if (IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) {
+ if (mode_ != Mode::Wrap &&
+ IsNCHWInputWithPaddingAlongHAndW(dimension_count, lower_pads, upper_pads)) {
// If we have entered here, it means the input can only be 4-D (NCHW), 3-D (CHW), or 2-D (HW)
// NCHW input
@@ -282,6 +348,8 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const {
input_dims,
input_strides,
lower_pads,
+      TArray<int64_t>(effective_input_extents),
+ input_offsets,
value,
      static_cast<int>(mode_),
      reinterpret_cast<const typename ToCudaType<T>::MappedType*>(input_tensor.Data<T>()),
diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu
index 6f530e800fdf2..6020769bf0ddf 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.cu
@@ -7,19 +7,27 @@
namespace onnxruntime {
namespace cuda {
-// PadMode enum from core/providers/cpu/tensor/pad.h, cannot use that header because of nvcc/onnxruntime incompatibility
+// PadMode enum from core/providers/cpu/tensor/padbase.h, cannot use that header because of nvcc/onnxruntime incompatibility
enum class PadMode : int {
Constant = 0,
Reflect,
- Edge
+ Edge,
+ Wrap
};
+__device__ __forceinline__ int64_t WrapCoordinate(int64_t coord, int64_t extent) {
+ int64_t wrapped = coord % extent;
+ return wrapped < 0 ? wrapped + extent : wrapped;
+}
+
template <typename T, int pad_mode>
__global__ void _PadKernel(
    const size_t shape_rank,
    const TArray<int64_t> input_dims,
    const TArray<int64_t> input_strides,
    const TArray<int64_t> lower_pads,
+    const TArray<int64_t> effective_input_extents,
+    const TArray<int64_t> input_offsets,
    const T pad_value,
    const T* input_data,
    const TArray<fast_divmod> fdm_output_strides,
@@ -33,33 +41,44 @@ __global__ void _PadKernel(
int out_coord, r;
fdm_output_strides[dim].divmod(output_index, out_coord, r);
output_index = r;
- int in_coord = 0;
- if (out_coord < lower_pads[dim]) {
- switch ((PadMode)pad_mode) {
- case PadMode::Constant:
- use_pad_value = true;
- break;
- case PadMode::Edge:
- in_coord = 0;
- break;
- case PadMode::Reflect:
- in_coord = lower_pads[dim] - out_coord;
- break;
- }
- } else if (out_coord >= lower_pads[dim] + input_dims[dim]) {
- switch ((PadMode)pad_mode) {
- case PadMode::Constant:
- use_pad_value = true;
- break;
- case PadMode::Edge:
- in_coord = input_dims[dim] - 1;
- break;
- case PadMode::Reflect:
- in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim]));
- break;
- }
+ int64_t in_coord = 0;
+    if constexpr (pad_mode == static_cast<int>(PadMode::Wrap)) {
+ const int64_t effective_input_extent = effective_input_extents[dim];
+ const int64_t pre_pad = lower_pads[dim] + input_offsets[dim];
+      const int64_t relative_coord = static_cast<int64_t>(out_coord) - pre_pad;
+ in_coord = input_offsets[dim] + WrapCoordinate(relative_coord, effective_input_extent);
} else {
- in_coord = out_coord - lower_pads[dim];
+ if (out_coord < lower_pads[dim]) {
+ switch ((PadMode)pad_mode) {
+ case PadMode::Constant:
+ use_pad_value = true;
+ break;
+ case PadMode::Edge:
+ in_coord = 0;
+ break;
+ case PadMode::Reflect:
+ in_coord = lower_pads[dim] - out_coord;
+ break;
+ case PadMode::Wrap:
+ break;
+ }
+ } else if (out_coord >= lower_pads[dim] + input_dims[dim]) {
+ switch ((PadMode)pad_mode) {
+ case PadMode::Constant:
+ use_pad_value = true;
+ break;
+ case PadMode::Edge:
+ in_coord = input_dims[dim] - 1;
+ break;
+ case PadMode::Reflect:
+ in_coord = input_dims[dim] - 2 - (out_coord - (lower_pads[dim] + input_dims[dim]));
+ break;
+ case PadMode::Wrap:
+ break;
+ }
+ } else {
+ in_coord = out_coord - lower_pads[dim];
+ }
}
input_index += input_strides[dim] * in_coord;
}
@@ -136,6 +155,8 @@ void PadImpl(
    const TArray<int64_t>& input_dims,
    const TArray<int64_t>& input_strides,
    const TArray<int64_t>& lower_pads,
+    const TArray<int64_t>& effective_input_extents,
+    const TArray<int64_t>& input_offsets,
const T pad_value,
const int pad_mode,
const T* input_data,
@@ -149,17 +170,22 @@ void PadImpl(
switch (pad_mode) {
case 0:
      _PadKernel<T, 0><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
- shape_rank, input_dims, input_strides, lower_pads,
+ shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets,
pad_value, input_data, fdm_output_strides, output_data, N);
break;
case 1:
      _PadKernel<T, 1><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
- shape_rank, input_dims, input_strides, lower_pads,
+ shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets,
pad_value, input_data, fdm_output_strides, output_data, N);
break;
case 2:
      _PadKernel<T, 2><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
- shape_rank, input_dims, input_strides, lower_pads,
+ shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets,
+ pad_value, input_data, fdm_output_strides, output_data, N);
+ break;
+ case 3:
+      _PadKernel<T, 3><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+ shape_rank, input_dims, input_strides, lower_pads, effective_input_extents, input_offsets,
pad_value, input_data, fdm_output_strides, output_data, N);
break;
}
@@ -211,6 +237,8 @@ void PadNCHWInputWithPaddingAlongHAndWImpl(
  template void PadImpl<T>(cudaStream_t stream, const size_t shape_rank,                            \
                           const TArray<int64_t>& input_dims, const TArray<int64_t>& input_strides, \
                           const TArray<int64_t>& lower_pads,                                       \
+                           const TArray<int64_t>& effective_input_extents,                          \
+                           const TArray<int64_t>& input_offsets,                                    \
const T pad_value, \
const int pad_mode, \
const T* input_data, \
diff --git a/onnxruntime/core/providers/cuda/tensor/pad_impl.h b/onnxruntime/core/providers/cuda/tensor/pad_impl.h
index dc700ea2304e9..96f158dd187fc 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/pad_impl.h
@@ -32,6 +32,8 @@ void PadImpl(
    const TArray<int64_t>& input_dims,
    const TArray<int64_t>& input_strides,
    const TArray<int64_t>& lower_pads,
+    const TArray<int64_t>& effective_input_extents,
+    const TArray<int64_t>& input_offsets,
const T pad_value,
const int pad_mode,
const T* input_data,
diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
index 9169f2e6b5ca9..990e4354c3626 100644
--- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
@@ -124,6 +124,37 @@ static void RunAllOpsetAllDomainPadTests(
}
}
+#ifdef USE_CUDA
+template <typename T>
+static void RunCudaOnlyOnnxOpsetPadTest(
+    int opset,
+    const std::vector<int64_t>& input_dims,
+    const std::vector<T>& input,
+    const std::vector<int64_t>& pads,
+    T value,
+    const std::vector<int64_t>& output_dims,
+    const std::vector<T>& output,
+    const std::string& mode = "constant") {
+ auto cuda_execution_provider = DefaultCudaExecutionProvider();
+ if (cuda_execution_provider == nullptr) {
+ GTEST_SKIP() << "CUDA execution provider is not available";
+ }
+
+ OpTester test("Pad", opset);
+ if (mode != "constant") {
+ test.AddAttribute("mode", mode);
+ }
+  test.AddInput<T>("data", input_dims, input);
+  test.AddInput<int64_t>("pads", {static_cast<int64_t>(pads.size())}, pads, true);
+  test.AddInput<T>("value", {}, {value}, true);
+  test.AddOutput<T>("output", output_dims, output);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.emplace_back(std::move(cuda_execution_provider));
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+#endif
+
// Some of the tests can't run on TensorrtExecutionProvider because only constant mode and value 0 of "Pad" node is supported.
// Those tests will fallback to other EP.
@@ -199,6 +230,48 @@ TYPED_TEST(PadOpTest, Pad_Edge_1D) {
"edge");
}
+#ifdef USE_CUDA
+TEST(PadOpTest, Pad_Edge_CudaOnly_MLFloat16_SupportedOpsets) {
+  const std::vector<int> supported_opsets{18, 19, 20, 21, 22, 23, 24, 25};
+ for (int opset : supported_opsets) {
+ SCOPED_TRACE(MakeString("opset: ", opset));
+ RunCudaOnlyOnnxOpsetPadTest(
+ opset,
+ {3, 2},
+ {MLFloat16(1.0f), MLFloat16(2.0f),
+ MLFloat16(3.0f), MLFloat16(4.0f),
+ MLFloat16(5.0f), MLFloat16(6.0f)},
+ {0, 2, 0, 1},
+ MLFloat16(0.0f),
+ {3, 5},
+ {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f),
+ MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(4.0f),
+ MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(6.0f)},
+ "edge");
+ }
+}
+
+TEST(PadOpTest, Pad_Wrap_CudaOnly_Float_SupportedOpsets) {
+  const std::vector<int> supported_opsets{19, 20, 21, 22, 23, 24, 25};
+ for (int opset : supported_opsets) {
+ SCOPED_TRACE(MakeString("opset: ", opset));
+ RunCudaOnlyOnnxOpsetPadTest(
+ opset,
+ {3, 2},
+ {1.0f, 2.0f,
+ 3.0f, 4.0f,
+ 5.0f, 6.0f},
+ {0, 1, 0, 1},
+ 0.0f,
+ {3, 4},
+ {2.0f, 1.0f, 2.0f, 1.0f,
+ 4.0f, 3.0f, 4.0f, 3.0f,
+ 6.0f, 5.0f, 6.0f, 5.0f},
+ "wrap");
+ }
+}
+#endif
+
TYPED_TEST(PadOpTest, Pad_Constant_2D) {
using T = TypeParam;
RunAllOpsetAllDomainPadTests({2, 2},
@@ -1391,9 +1464,7 @@ TEST(PadOpTest, Pad_Wrap_NegativeFront_PositiveBack) {
// Post-slice core: [4]; wrap 3 -> [4, 4, 4, 4]
const std::vector expected_data = {4, 4, 4, 4};
- // CUDA registers only up to 18 and does not impl wrap mode
- // so we force version to 19 to automatically exclude EPs that do not
- // implement wrap mode similar to the above tests.
+ // Use opset 19 to exercise wrap mode, which is supported from Pad-19 onward.
OpTester test("Pad", 19);
test.AddInput("data", input_shape, input_data);
test.AddInput("pads", {static_cast(pads.size())}, pads, true);