diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index cc18ece351705..591d2c6806ea7 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -875,7 +875,8 @@ Do not modify directly.*
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)|
|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)|
|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**
or
*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**
or
*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|19+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||18|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index eb29e4edbf897..b4cb6c6bd122c 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1443,11 +1443,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, int32_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize);
// Opset 19
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool);
@@ -1500,6 +1500,11 @@ class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear);
#endif
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Reshape);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Shape);
@@ -2513,11 +2518,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// Opset 19-20
BuildKernelCreateInfo,
@@ -2572,6 +2577,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
#endif
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc
index 97d4eb71e970a..4f38b50b43c76 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize.cc
+++ b/onnxruntime/core/providers/cuda/tensor/resize.cc
@@ -40,10 +40,22 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType()), \
Resize); \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Resize, \
+ kOnnxDomain, \
+ 18, 18, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 3) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \
+ Resize); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Resize, \
kOnnxDomain, \
- 18, \
+ 19, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
index 9c9a27360b479..965b02d9445fc 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
@@ -657,6 +657,8 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \
CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \
TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \
+ CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL_SYMMETRIC, \
+ TransformCoordinate_HALF_PIXEL_SYMMETRIC, __VA_ARGS__) \
default: \
ORT_THROW("unknown ResizeCoordinateTransformationMode"); \
} \
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
index a96d4c82a7fdc..6bc66b38697b5 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
@@ -71,6 +71,8 @@ struct NearestPixel_CEIL {
TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \
CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \
TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \
+ CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL_SYMMETRIC, \
+ TransformCoordinate_HALF_PIXEL_SYMMETRIC, __VA_ARGS__) \
default: \
ORT_THROW("unknown ResizeCoordinateTransformationMode"); \
} \
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h
index 6e960be4ec09c..fdde5840edbc5 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h
@@ -65,6 +65,17 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE {
}
};
+struct TransformCoordinate_HALF_PIXEL_SYMMETRIC {  // coordinate functor for the HALF_PIXEL_SYMMETRIC mode
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized,
+                                                       float length_original, float, float) const {
+    const float scaled_length = x_scale * length_original;     // original length multiplied by the scale
+    const float length_ratio = length_resized / scaled_length;  // resized length relative to scaled length
+    const float midpoint = length_original / 2.0f;              // half of the original length
+    const float shift = midpoint * (1.0f - length_ratio);       // zero when both lengths agree
+    return shift + ((x_resized + 0.5f) / x_scale) - 0.5f;       // half-pixel mapping plus the shift
+  }
+};
+
size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
const gsl::span& output_dims);