1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -759,6 +759,7 @@ Do not modify directly.*
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
|Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
14 changes: 13 additions & 1 deletion onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -20,7 +20,7 @@ namespace cuda {

// float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions
// CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
__device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); }
__device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); }
@@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

template <typename T>
__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }

template <typename T>
__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }

template <typename T>
__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }

template <>
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }

template <typename T>
__device__ __inline__ T _Normcdf(T a);

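Note on the helpers above: the two _Signum overloads pick the comparison form at compile time. For unsigned element types only T(0) < a is evaluated, yielding 0 or 1 and avoiding an always-false a < 0 comparison; signed types use the branchless (T(0) < a) - (a < T(0)) idiom, yielding -1, 0, or 1. The explicit half specialization forces the signed path because std::is_signed is false for CUDA's half class type. A minimal host-side sketch of the same idiom (illustrative names only, not part of the ONNX Runtime sources):

// Host-side illustration of the signum dispatch used by _Signum/_Sign above.
#include <iostream>
#include <type_traits>

template <typename T>
T signum(T a, std::false_type /* unsigned */) { return T(0) < a; }              // 0 or 1

template <typename T>
T signum(T a, std::true_type /* signed */) { return (T(0) < a) - (a < T(0)); }  // -1, 0, or 1

template <typename T>
T sign(T a) { return signum(a, std::is_signed<T>()); }

int main() {
  std::cout << sign(-3) << ' ' << sign(0) << ' ' << sign(7u) << ' ' << sign(-2.5) << '\n';  // -1 0 1 -1
}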
22 changes: 22 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,
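With the kernel classes declared and their BuildKernelCreateInfo entries added to RegisterCudaKernels, opset-13 Sign should now resolve to the CUDA execution provider for all eleven registered element types. A minimal smoke-test sketch using the ONNX Runtime C++ API follows; the model file name, input/output names, and shape are hypothetical placeholders for a graph containing a single Sign node:

// Illustrative only: assumes "sign.onnx" holds a single opset-13 Sign node
// with a float input named "input" and an output named "output".
#include <onnxruntime_cxx_api.h>
#include <array>
#include <iostream>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sign_smoke_test");
  Ort::SessionOptions so;
  OrtCUDAProviderOptions cuda_options{};          // device 0, default settings
  so.AppendExecutionProvider_CUDA(cuda_options);  // requires a CUDA-enabled build
  Ort::Session session(env, "sign.onnx", so);

  std::array<float, 4> x{-2.5f, -0.0f, 0.0f, 3.0f};
  std::array<int64_t, 1> shape{4};
  auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(mem, x.data(), x.size(),
                                                     shape.data(), shape.size());
  const char* input_names[] = {"input"};
  const char* output_names[] = {"output"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input, 1, output_names, 1);

  const float* y = outputs[0].GetTensorData<float>();
  for (size_t i = 0; i < x.size(); ++i) std::cout << y[i] << ' ';  // expected: -1 0 0 1
  std::cout << '\n';
}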
@@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
UNARY_OP_HFD(Exp, 13)
UNARY_OP_HFD(Erf, 13)
UNARY_OP_BWUZCSILHFD(Sign, 13)

UNARY_LOGICALOP_NOT_TYPED(1, bool)
UNARY_OP_HFD(Round, 11)
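For reference, the per-type letters in the UNARY_OP_BWUZCSILHFD suffix appear to follow ONNX Runtime's usual shorthand for element types: B = uint8, W = uint16, U = uint32, Z = uint64, C = int8, S = int16, I = int32, L = int64, H = MLFloat16, F = float, D = double. That list matches the eleven Sign kernel classes registered in cuda_execution_provider.cc above.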
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Sign final : public UnaryElementwise {
public:
Sign(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

} // namespace cuda
} // namespace onnxruntime
170 changes: 85 additions & 85 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign)

// When casting, half needs to be converted via float type from most other types
template <typename T>
@@ -119,52 +120,52 @@ struct OP_Cast {
}
};

#define IMPL_CAST_IMPL(InT, OutT) \
  void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
    UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
}

#define IMPL_CAST_IMPL_THROW(InT, OutT) \
  void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
    ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
}

#if !defined(DISABLE_FLOAT8_TYPES)

#define IMPL_CAST_IMPL_FROM(T) \
  IMPL_CAST_IMPL(T, half) \
  IMPL_CAST_IMPL(T, float) \
  IMPL_CAST_IMPL(T, double) \
  IMPL_CAST_IMPL(T, int8_t) \
  IMPL_CAST_IMPL(T, int16_t) \
  IMPL_CAST_IMPL(T, int32_t) \
  IMPL_CAST_IMPL(T, int64_t) \
  IMPL_CAST_IMPL(T, uint8_t) \
  IMPL_CAST_IMPL(T, uint16_t) \
  IMPL_CAST_IMPL(T, uint32_t) \
  IMPL_CAST_IMPL(T, uint64_t) \
  IMPL_CAST_IMPL(T, bool) \
  IMPL_CAST_IMPL(T, BFloat16) \
  IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
  IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \
IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ)

#else

#define IMPL_CAST_IMPL_FROM(T) \
  IMPL_CAST_IMPL(T, half) \
  IMPL_CAST_IMPL(T, float) \
  IMPL_CAST_IMPL(T, double) \
  IMPL_CAST_IMPL(T, int8_t) \
  IMPL_CAST_IMPL(T, int16_t) \
  IMPL_CAST_IMPL(T, int32_t) \
  IMPL_CAST_IMPL(T, int64_t) \
  IMPL_CAST_IMPL(T, uint8_t) \
  IMPL_CAST_IMPL(T, uint16_t) \
  IMPL_CAST_IMPL(T, uint32_t) \
  IMPL_CAST_IMPL(T, uint64_t) \
  IMPL_CAST_IMPL(T, bool) \
IMPL_CAST_IMPL(T, BFloat16)

#endif
@@ -199,58 +200,58 @@ struct OP_CastNoSat {

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080

#define OP_CAST(T, NVT) \
  template <> \
  struct OP_CastSat<half, T> { \
    __device__ __inline__ T operator()(const half& v) const { \
      return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
    } \
  }; \
  template <> \
  struct OP_CastNoSat<half, T> { \
    __device__ __inline__ T operator()(const half& v) const { \
      return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
    } \
  }; \
  template <> \
  struct OP_CastSat<float, T> { \
    __device__ __inline__ T operator()(const float& v) const { \
      return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
    } \
  }; \
  template <> \
  struct OP_CastNoSat<float, T> { \
    __device__ __inline__ T operator()(const float& v) const { \
      return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
    } \
  };

#else

#define OP_CAST(T, NVT) \
  template <> \
  struct OP_CastSat<half, T> { \
    __device__ __inline__ T operator()(const half& v) const { \
      return T(__half2float(v), true); \
    } \
  }; \
  template <> \
  struct OP_CastNoSat<half, T> { \
    __device__ __inline__ T operator()(const half& v) const { \
      return T(__half2float(v), false); \
    } \
  }; \
  template <> \
  struct OP_CastSat<float, T> { \
    __device__ __inline__ T operator()(const float& v) const { \
      return T(v, true); \
    } \
  }; \
  template <> \
  struct OP_CastNoSat<float, T> { \
    __device__ __inline__ T operator()(const float& v) const { \
      return T(v, false); \
    } \
  };

#endif
@@ -260,14 +261,13 @@ struct OP_CastNoSat {
OP_CAST(Float8E4M3FN, __NV_E4M3)
OP_CAST(Float8E5M2, __NV_E5M2)


#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
  void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \
    if (saturate) { \
      UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
    } else { \
      UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
    } \
}

EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN)
@@ -31,7 +31,8 @@ namespace cuda {
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \
UNARY_OP_NAME_EXPR(Sign, _Sign(a))

#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
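UNARY_OP_NAME_EXPR pairs each operator name with its per-element expression; together with UNARY_ELEMENTWISE_IMPL_DECLARATION this is what turns _Sign(a) into a launchable elementwise kernel, following the same functor-plus-UnaryElementWiseImpl pattern visible in the OP_Cast code above. A rough, simplified sketch of the shape the generated code takes for Sign (not the literal macro expansion):

// Illustrative only: approximate form of what the macros generate for Sign.
template <typename T>
struct OP_Sign {
  __device__ __inline__ T operator()(const T& a) const { return _Sign(a); }
};

template <typename T>
void Impl_Sign(cudaStream_t stream, const T* input_data, T* output_data, size_t count) {
  UnaryElementWiseImpl(stream, input_data, output_data, OP_Sign<T>(), count);
}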
18 changes: 15 additions & 3 deletions onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

template <typename T>
__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }

template <typename T>
__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }

template <typename T>
__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }

template <>
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }

template <typename T>
__device__ __inline__ T _Normcdf(T a);

@@ -337,7 +349,7 @@ struct GridDim {
};

// aligned vector generates vectorized load/store
template <typename T, int vec_size>
struct alignas(sizeof(T) * vec_size) aligned_vector {
T val[vec_size];
};
@@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector {
// HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
// TODO ROCM added support recently, should verify.
#define HIP_KERNEL_ASSERT(...)
// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)

// WARP related definitions and functions
constexpr int GPU_WARP_SIZE = warpSize;
inline int GPU_WARP_SIZE_HOST = warpSizeDynamic();

template <typename T>
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) {