diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 2e6f329363a50..d46f3ed9bd262 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -759,6 +759,7 @@ Do not modify directly.*
 |Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float), tensor(float16)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index a50b53315ec9a..0d9928baa86e0 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -20,7 +20,7 @@ namespace cuda {
 // float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions
 // CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2)))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
 __device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
 __device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); }
 __device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); }
@@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
 template <typename T>
 __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

+template <typename T>
+__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
+
+template <typename T>
+__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
+
+template <typename T>
+__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
+
+template <>
+__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
+
 template <typename T>
 __device__ __inline__ T _Normcdf(T a);
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index aa60db4d07222..ad892eab3b843 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index f026444328b24..9ede1f8d90ecc 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13)
 UNARY_OP_HFD(Log, 13)
 UNARY_OP_HFD(Exp, 13)
 UNARY_OP_HFD(Erf, 13)
+UNARY_OP_BWUZCSILHFD(Sign, 13)

 UNARY_LOGICALOP_NOT_TYPED(1, bool)
 UNARY_OP_HFD(Round, 11)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
index 3ff97a60114df..775b78c43a736 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise {
   Status ComputeInternal(OpKernelContext* context) const override;
 };

+template <typename T>
+class Sign final : public UnaryElementwise {
+ public:
+  Sign(const OpKernelInfo& info) : UnaryElementwise(info) {}
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index ac7cc1126acb7..1298d53338337 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign)

 // When casting, half needs to be converted via float type from most other types
 template <typename T>
@@ -119,52 +120,52 @@ struct OP_Cast {
   }
 };

-#define IMPL_CAST_IMPL(InT, OutT) \
+#define IMPL_CAST_IMPL(InT, OutT) \
   void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
-    UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
+    UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
   }

-#define IMPL_CAST_IMPL_THROW(InT, OutT) \
+#define IMPL_CAST_IMPL_THROW(InT, OutT) \
   void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
-    ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
+    ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
   }

 #if !defined(DISABLE_FLOAT8_TYPES)

-#define IMPL_CAST_IMPL_FROM(T) \
-  IMPL_CAST_IMPL(T, half) \
-  IMPL_CAST_IMPL(T, float) \
-  IMPL_CAST_IMPL(T, double) \
-  IMPL_CAST_IMPL(T, int8_t) \
-  IMPL_CAST_IMPL(T, int16_t) \
-  IMPL_CAST_IMPL(T, int32_t) \
-  IMPL_CAST_IMPL(T, int64_t) \
-  IMPL_CAST_IMPL(T, uint8_t) \
-  IMPL_CAST_IMPL(T, uint16_t) \
-  IMPL_CAST_IMPL(T, uint32_t) \
-  IMPL_CAST_IMPL(T, uint64_t) \
-  IMPL_CAST_IMPL(T, bool) \
-  IMPL_CAST_IMPL(T, BFloat16) \
-  IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
-  IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
+#define IMPL_CAST_IMPL_FROM(T) \
+  IMPL_CAST_IMPL(T, half) \
+  IMPL_CAST_IMPL(T, float) \
+  IMPL_CAST_IMPL(T, double) \
+  IMPL_CAST_IMPL(T, int8_t) \
+  IMPL_CAST_IMPL(T, int16_t) \
+  IMPL_CAST_IMPL(T, int32_t) \
+  IMPL_CAST_IMPL(T, int64_t) \
+  IMPL_CAST_IMPL(T, uint8_t) \
+  IMPL_CAST_IMPL(T, uint16_t) \
+  IMPL_CAST_IMPL(T, uint32_t) \
+  IMPL_CAST_IMPL(T, uint64_t) \
+  IMPL_CAST_IMPL(T, bool) \
+  IMPL_CAST_IMPL(T, BFloat16) \
+  IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
+  IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
   IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \
   IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ)

 #else

-#define IMPL_CAST_IMPL_FROM(T) \
-  IMPL_CAST_IMPL(T, half) \
-  IMPL_CAST_IMPL(T, float) \
-  IMPL_CAST_IMPL(T, double) \
-  IMPL_CAST_IMPL(T, int8_t) \
-  IMPL_CAST_IMPL(T, int16_t) \
-  IMPL_CAST_IMPL(T, int32_t) \
-  IMPL_CAST_IMPL(T, int64_t) \
-  IMPL_CAST_IMPL(T, uint8_t) \
-  IMPL_CAST_IMPL(T, uint16_t) \
-  IMPL_CAST_IMPL(T, uint32_t) \
-  IMPL_CAST_IMPL(T, uint64_t) \
-  IMPL_CAST_IMPL(T, bool) \
+#define IMPL_CAST_IMPL_FROM(T) \
+  IMPL_CAST_IMPL(T, half) \
+  IMPL_CAST_IMPL(T, float) \
+  IMPL_CAST_IMPL(T, double) \
+  IMPL_CAST_IMPL(T, int8_t) \
+  IMPL_CAST_IMPL(T, int16_t) \
+  IMPL_CAST_IMPL(T, int32_t) \
+  IMPL_CAST_IMPL(T, int64_t) \
+  IMPL_CAST_IMPL(T, uint8_t) \
+  IMPL_CAST_IMPL(T, uint16_t) \
+  IMPL_CAST_IMPL(T, uint32_t) \
+  IMPL_CAST_IMPL(T, uint64_t) \
+  IMPL_CAST_IMPL(T, bool) \
   IMPL_CAST_IMPL(T, BFloat16)

 #endif
@@ -199,58 +200,58 @@ struct OP_CastNoSat {

 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080

-#define OP_CAST(T, NVT) \
-  template <> \
-  struct OP_CastSat<half, T> { \
-    __device__ __inline__ T operator()(const half& v) const { \
       return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
-    } \
-  }; \
-  template <> \
-  struct OP_CastNoSat<half, T> { \
-    __device__ __inline__ T operator()(const half& v) const { \
-      return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
-    } \
-  }; \
-  template <> \
-  struct OP_CastSat<float, T> { \
-    __device__ __inline__ T operator()(const float& v) const { \
-      return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
-    } \
-  }; \
-  template <> \
-  struct OP_CastNoSat<float, T> { \
-    __device__ __inline__ T operator()(const float& v) const { \
-      return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
-    } \
+    } \
+  }; \
+  template <> \
+  struct OP_CastNoSat<half, T> { \
+    __device__ __inline__ T operator()(const half& v) const { \
+      return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
+    } \
+  }; \
+  template <> \
+  struct OP_CastSat<float, T> { \
+    __device__ __inline__ T operator()(const float& v) const { \
+      return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
+    } \
+  }; \
+  template <> \
+  struct OP_CastNoSat<float, T> { \
+    __device__ __inline__ T operator()(const float& v) const { \
+      return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
+    } \
   };

 #else
-#define OP_CAST(T, NVT) \
-  template <> \
-  struct OP_CastSat<half, T> { \
-    __device__ __inline__ T operator()(const half& v) const { \
-      return T(__half2float(v), true); \
-    } \
-  }; \
-  template <> \
-  struct OP_CastNoSat<half, T> { \
-    __device__ __inline__ T operator()(const half& v) const { \
-      return T(__half2float(v), false); \
-    } \
-  }; \
-  template <> \
-  struct OP_CastSat<float, T> { \
+#define OP_CAST(T, NVT) \
+  template <> \
+  struct OP_CastSat<half, T> { \
+    __device__ __inline__ T operator()(const half& v) const { \
+      return T(__half2float(v), true); \
+    } \
+  }; \
+  template <> \
+  struct OP_CastNoSat<half, T> { \
+    __device__ __inline__ T operator()(const half& v) const { \
+      return T(__half2float(v), false); \
+    } \
+  }; \
+  template <> \
+  struct OP_CastSat<float, T> { \
     __device__ __inline__ T operator()(const float& v) const { \
-      return T(v, true); \
-    } \
-  }; \
-  template <> \
-  struct OP_CastNoSat<float, T> { \
+      return T(v, true); \
+    } \
+  }; \
+  template <> \
+  struct OP_CastNoSat<float, T> { \
     __device__ __inline__ T operator()(const float& v) const { \
-      return T(v, false); \
-    } \
+      return T(v, false); \
+    } \
   };

 #endif
@@ -260,14 +261,13 @@ struct OP_CastNoSat {

 OP_CAST(Float8E4M3FN, __NV_E4M3)
 OP_CAST(Float8E5M2, __NV_E5M2)
-
-#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
+#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
   void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \
-    if (saturate) { \
-      UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
-    } else { \
-      UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
-    } \
+    if (saturate) { \
+      UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
+    } else { \
+      UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
+    } \
   }

 EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
index 3d4868b54abe6..608a81a24cf4f 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -31,7 +31,8 @@ namespace cuda {
   UNARY_OP_NAME_EXPR(Not, !a) \
   UNARY_OP_NAME_EXPR(Round, _Round(a)) \
   UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
-  UNARY_OP_NAME_EXPR(Cos, _Cos(a))
+  UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \
+  UNARY_OP_NAME_EXPR(Sign, _Sign(a))

 #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
   template <typename T> \
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index 5c516aac65aab..429ceb1f7c699 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
 template <typename T>
 __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

+template <typename T>
+__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
+
+template <typename T>
+__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
+
+template <typename T>
+__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
+
+template <>
+__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
+
 template <typename T>
 __device__ __inline__ T _Normcdf(T a);
@@ -337,7 +349,7 @@ struct GridDim {
 };

 // aligned vector generates vectorized load/store
-template
+template
 struct alignas(sizeof(T) * vec_size) aligned_vector {
   T val[vec_size];
 };
@@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector {
 // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
 // TODO ROCM added support recently, should verify.
 #define HIP_KERNEL_ASSERT(...)
-//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
+// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)

 // WARP related definitions and functions
 constexpr int GPU_WARP_SIZE = warpSize;
-inline int GPU_WARP_SIZE_HOST= warpSizeDynamic();
+inline int GPU_WARP_SIZE_HOST = warpSizeDynamic();

 template <typename T>
 __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) {
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 61e46767e8f1a..c9975d0bc76c0 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);

 // OpSet 14
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
@@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,

       // OpSet 14
       BuildKernelCreateInfo,
diff --git a/onnxruntime/test/providers/cpu/math/sign_test.cc b/onnxruntime/test/providers/cpu/math/sign_test.cc
index 12844068c47d2..15b3f40faa791 100644
--- a/onnxruntime/test/providers/cpu/math/sign_test.cc
+++ b/onnxruntime/test/providers/cpu/math/sign_test.cc
@@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) {

 TEST(MathOpTest, Sign_uint64) {
   using namespace test_sign_internal;
-  OpTester test("Sign", 9);
+  OpTester test("Sign", 13);

   std::vector<int64_t> input_dims{7};
   std::vector<uint64_t> input;
@@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) {
 // we disable this test for openvino as openvino ep supports only FP32 Precision
 TEST(MathOpTest, Sign_int64) {
   using namespace test_sign_internal;
-  OpTester test("Sign", 9);
+  OpTester test("Sign", 13);

   std::vector<int64_t> input_dims{7};
   std::vector<int64_t> input;
@@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) {

 TEST(MathOpTest, Sign_float) {
   using namespace test_sign_internal;
-  OpTester test("Sign", 9);
+  OpTester test("Sign", 13);

   std::vector<int64_t> input_dims{7};
   std::vector<float> input;
@@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) {

 TEST(MathOpTest, Sign_double) {
   using namespace test_sign_internal;
-  OpTester test("Sign", 9);
+  OpTester test("Sign", 13);

   std::vector<int64_t> input_dims{7};
   std::vector<double> input;
@@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) {
 }

 TEST(MathOpTest, Sign_MLFloat16) {
   using namespace test_sign_internal;
-  OpTester test("Sign", 9);
+  OpTester test("Sign", 13);

   std::vector<int64_t> input_dims{7};
   std::vector<MLFloat16> input;
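
Note (illustrative addition, not part of the patch): the new `_Signum`/`_Sign` device helpers above use the standard branchless signum idiom — `(T(0) < a) - (a < T(0))` for signed types, `T(0) < a` for unsigned ones — which matches ONNX `Sign` semantics (`Sign(0) == 0`). The explicit `half` specialization forces the signed path because `std::is_signed` is false for the `half` class type. Below is a minimal host-side C++ sketch of the same dispatch; the `Signum`/`Sign` names are illustrative, mirroring the device helpers rather than quoting them.

```cpp
// Host-side sketch (not part of the patch) of the branchless signum dispatch.
#include <cassert>
#include <type_traits>

// Unsigned types can only be 0 or positive, so "a < 0" would be a tautology;
// a single comparison against zero suffices.
template <typename T>
T Signum(T a, std::false_type /* is_signed */) {
  return static_cast<T>(T(0) < a);
}

// Signed (and floating-point) types: (0 < a) - (a < 0) yields -1, 0, or +1.
template <typename T>
T Signum(T a, std::true_type /* is_signed */) {
  return static_cast<T>((T(0) < a) - (a < T(0)));
}

// Dispatch on signedness at compile time, as the device helpers do.
template <typename T>
T Sign(T a) {
  return Signum(a, std::is_signed<T>());
}

int main() {
  assert(Sign(-3.5f) == -1.0f);  // negative float -> -1
  assert(Sign(0) == 0);          // sign(0) == 0, per the ONNX Sign spec
  assert(Sign(42u) == 1u);       // unsigned path can only return 0 or 1
  return 0;
}
```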