Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,8 @@ Do not modify directly.*
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|19+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||18|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
Expand Down
30 changes: 20 additions & 10 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1443,11 +1443,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, int32_t, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize);

// Opset 19
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool);
Expand Down Expand Up @@ -1500,6 +1500,11 @@ class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear);
#endif
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Reshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Shape);

Expand Down Expand Up @@ -2513,11 +2518,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, double, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, MLFloat16, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, int32_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize)>,

// Opset 19-20
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool)>,
Expand Down Expand Up @@ -2572,6 +2577,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
#endif

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Shape)>,

Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,22 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()), \
Resize<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Resize, \
kOnnxDomain, \
18, 18, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()), \
Resize<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Resize, \
kOnnxDomain, \
18, \
19, \
Comment thread
hariharans29 marked this conversation as resolved.
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,8 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \
CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \
TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \
CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL_SYMMETRIC, \
TransformCoordinate_HALF_PIXEL_SYMMETRIC, __VA_ARGS__) \
default: \
ORT_THROW("unknown ResizeCoordinateTransformationMode"); \
} \
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ struct NearestPixel_CEIL {
TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \
CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \
TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \
CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL_SYMMETRIC, \
TransformCoordinate_HALF_PIXEL_SYMMETRIC, __VA_ARGS__) \
default: \
ORT_THROW("unknown ResizeCoordinateTransformationMode"); \
} \
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE {
}
};

struct TransformCoordinate_HALF_PIXEL_SYMMETRIC {
Comment thread
hariharans29 marked this conversation as resolved.
__device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized,
float length_original, float, float) const {
float output_width = x_scale * length_original;
float adjustment = length_resized / output_width;
float center = length_original / 2.0f;
float offset = center * (1.0f - adjustment);
return offset + ((x_resized + 0.5f) / x_scale) - 0.5f;
}
};

size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
const gsl::span<const int64_t>& output_dims);

Expand Down
Loading