# ===== cmake/onnxruntime_mlas.cmake — SVE / NEON FP16 elementwise kernel sources =====
    if (onnxruntime_USE_SVE)
      list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
      list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
      list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp)
      list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlas_sve_fp16.h)
      # SVE translation units need SVE+FP16 target flags regardless of the global arch flags.
      set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
      set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
      list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
    endif()

    # NEON FP16 erf/gelu kernels (headers listed so IDE projects pick them up).
        list(APPEND mlas_platform_srcs
          ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
          ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
          ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
          ${MLAS_SRC_DIR}/erf_neon_fp16.h
          ${MLAS_SRC_DIR}/erf_neon_fp16.cpp
          ${MLAS_SRC_DIR}/gelu_neon_fp16.h
          ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
        )
        if (onnxruntime_USE_ARM_NEON_NCHWC)
          # (existing NCHWC sources unchanged)
        endif()

      # The NEON FP16 kernels only need the fp16 extension, not SVE.
      set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
      set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
  endif()

  if(ONNXRUNTIME_MLAS_MULTI_ARCH)
# ...

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
# ...
endif()
# fix: restored the trailing newline that the patch removed ("\ No newline at end of file").

# ===== onnxruntime/core/mlas/inc/mlas.h (additions follow) =====
f5e5343376f65..f7c2908d0ab8b 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2237,3 +2237,72 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); + +/** + * @brief Enumeration of supported GELU algorithm variants. + * + * MlasGeluErf - Exact GELU implementation using the error function (erf). + * MlasGeluTanh - Approximate GELU implementation using tanh-based formulation. + */ +typedef enum MLAS_GELU_ALGORITHM { + MlasGeluErf = 0, + MlasGeluTanh = 1 +} MLAS_GELU_ALGORITHM; + +/** + * @brief Computes element-wise FP16 error function (erf). + * + * This routine computes: + * Output[i] = erf(Input[i]) + * for N elements. Depending on platform capabilities, this may use + * vectorized FP16 intrinsics or fall back to a scalar FP32 conversion path. + * + * @param Input Pointer to input buffer of N FP16 elements. + * @param Output Pointer to output buffer of N FP16 elements. + * @param Input_tmp_fp32 Pointer to caller-allocated scratch buffer of N floats + * for FP32 input conversion (used only on fallback path). + * @param Output_tmp_fp32 Pointer to caller-allocated scratch buffer of N floats + * for FP32 output conversion (used only on fallback path). + * @param N Number of elements to process. + */ +void +MLASCALL +MlasComputeFP16Erf( + const MLAS_FP16* Input, + MLAS_FP16* Output, + float* Input_tmp_fp32, + float* Output_tmp_fp32, + size_t N +); + +/** + * @brief Computes element-wise FP16 GELU activation. + * + * This routine computes: + * + * If algo == MlasGeluTanh (approximate): + * GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + * + * If algo == MlasGeluErf (exact): + * GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + * + * Depending on platform capabilities, this may use vectorized FP16 kernels + * (SVE/NEON) or fall back to a scalar FP32 conversion path. + * + * @param input Pointer to input buffer of FP16 elements. 
+ * @param output Pointer to output buffer of FP16 elements. + * @param temp Temporary scratch buffer of at least 'count' FP16 elements. + * Required by certain vectorized implementations. May be unused + * in scalar fallback paths. + * @param count Number of elements to process. + * @param algo GELU algorithm variant (exact erf or tanh approximation). + */ +void +MLASCALL +MlasComputeFP16Gelu( + const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + size_t count, + MLAS_GELU_ALGORITHM algo +); diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp index f9724062e1f4d..04a7c67a8ef10 100644 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -266,3 +266,23 @@ Return Value: MlasErfKernel(Input, Output, N); #endif } + +void +MLASCALL +MlasComputeFP16Erf( + const MLAS_FP16* Input, + MLAS_FP16* Output, + float* Input_tmp_fp32, + float* Output_tmp_fp32, + size_t N + ) +{ + if(GetMlasPlatform().ErfFP16KernelRoutine){ + GetMlasPlatform().ErfFP16KernelRoutine(Input, Output, N); + return; + } + + MlasConvertHalfToFloatBuffer(Input, Input_tmp_fp32, N); + MlasComputeErf(Input_tmp_fp32, Output_tmp_fp32, N); + MlasConvertFloatToHalfBuffer(Output_tmp_fp32, Output, N); +} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp new file mode 100644 index 0000000000000..7973c8d1a7db0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.cpp @@ -0,0 +1,156 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + erf_neon_fp16.cpp + +Abstract: + + This module contains the procedure prototypes for the ERF NEON FP16 intrinsics. 
+ +--*/ + +#include "erf_neon_fp16.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + +using _mlas_fp16_ = uint16_t; +// Helpers to safely convert between float and FP16-bit representation +static float +fp16_to_float(uint16_t h) +{ + __fp16 tmp; + std::memcpy(&tmp, &h, sizeof(h)); + return (float)tmp; +} + +static uint16_t +float_to_fp16(float f) +{ + __fp16 tmp = (__fp16)f; + uint16_t h; + std::memcpy(&h, &tmp, sizeof(h)); + return h; +} + +static inline MLAS_FLOAT16X8 +exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x) +{ + const float16_t a0 = 6.0f; + MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0); + x = MlasMinimumFloat16(x, max_x); + + const float16_t c0 = 1.330f; + const float16_t c1 = -0.390f; + const float16_t c2 = 0.0288f; + + const float16_t d0 = 1.338f; + const float16_t d1 = 0.848f; + const float16_t d2 = 0.467f; + + MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0); + MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1); + MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2); + + MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0); + MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1); + MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2); + MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x); + MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v); + num = MlasMultiplyAddFloat16(c2v, x2, num); + MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v); + den = MlasMultiplyAddFloat16(d2v, x2, den); + MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den); + recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip)); + recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip)); + MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip); + return result; +} + +void +MlasNeonErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) +{ + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + const float16_t p = 0.328f; + const float16_t a1 = 0.2505f; + const float16_t 
a2 = -0.2881f; + const float16_t a3 = 1.4102f; + const float16_t a4 = -1.423f; + const float16_t a5 = 1.0547f; + + MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p); + MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1); + MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2); + MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3); + MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4); + MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5); + + constexpr float16_t one_fp16 = 1.0f; + constexpr float16_t neg_one_fp16 = -1.0f; + constexpr float16_t zero_fp16 = 0.0f; + constexpr float16_t four_fp16 = 4.0f; + + MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16); + MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16); + MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16); + MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16); + + size_t i = 0; + for (; i + 8 <= N; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&input[i]); + MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero); + MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone); + MLAS_FLOAT16X8 absx = MlasAbsFloat16(x); + MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth); + MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth); + MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone); + MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom); + t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t)); + t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t)); + MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t); + MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t); + MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t); + MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t); + MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t); + poly = MlasMultiplyAddFloat16(va2, t2, poly); + poly = MlasMultiplyAddFloat16(va3, t3, poly); + poly = MlasMultiplyAddFloat16(va4, t4, poly); + poly = MlasMultiplyAddFloat16(va5, t5, poly); + MLAS_FLOAT16X8 x2 = 
MlasMultiplyFloat16(absx_clamped, absx_clamped); + MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2); + MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2); + MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp); + MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term); + erf_approx = MlasMinimumFloat16(erf_approx, vone); + erf_approx = MlasMaximumFloat16(erf_approx, vneg_one); + MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign); + MlasStoreFloat16x8(&output[i], result); + } + + for (; i < N; i++) { + float x = fp16_to_float(input[i]); + float sign = (x < 0) ? -1.0f : 1.0f; + float absx = fabsf(x); + + if (absx > 4.0f) { + output[i] = float_to_fp16(sign); + continue; + } + + float t = 1.0f / (1.0f + p * absx); + float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t; + float exp_neg_x2 = expf(-absx * absx); + float erf_approx = sign * (1.0f - poly * exp_neg_x2); + if (erf_approx > 1.0f) erf_approx = 1.0f; + if (erf_approx < -1.0f) erf_approx = -1.0f; + + output[i] = float_to_fp16(erf_approx); + } +} +#endif diff --git a/onnxruntime/core/mlas/lib/erf_neon_fp16.h b/onnxruntime/core/mlas/lib/erf_neon_fp16.h new file mode 100644 index 0000000000000..7918df0ea3d1e --- /dev/null +++ b/onnxruntime/core/mlas/lib/erf_neon_fp16.h @@ -0,0 +1,27 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + erf_neon_fp16.h + +Abstract: + + This module contains the procedure prototypes for the ERF NEON FP16 intrinsics. 
+ +--*/ + +#pragma once + +#include + +#include "mlasi.h" +#include "fp16_common.h" +#include "softmax_kernel_neon.h" +#include + +void MlasNeonErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N); diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index d4713cce5a176..52d57daca67a4 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -19,7 +19,7 @@ Module Name: #include "mlas_float16.h" #include "mlasi.h" -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) // TODO!! Add intel fp16 implementations @@ -579,4 +579,56 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector) return vshl_n_s16(Vector, ShiftCount); } -#endif // fp16 vector intrinsic supported +// NEON FP16 vector intrinsics +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasBroadcastF16Float16x8(float16_t Value) { return vdupq_n_f16(Value); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasLoadf16Float16x8(const float16_t* Buffer) { return vld1q_f16(Buffer); } + +MLAS_FORCEINLINE +void +MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector) +{ + vst1q_f16(Buffer, Vector); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasReciprocalStepFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vrecpsq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector) +{ + return vrecpeq_f16(Vector); +} + +MLAS_FORCEINLINE +MLAS_UINT16X8 +MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vcltq_f16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasAbsFloat16(MLAS_FLOAT16X8 Vector) +{ + return vabsq_f16(Vector); +} + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +{ + return vbslq_f16(Vector, Vector1, Vector2); +} +#endif // NEON FP16 vector intrinsics supported 
+#endif // mlas fp16 intrinsic supported \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index dc25611652c77..855b8a282f48e 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -22,7 +22,6 @@ constexpr float kInvSqrt2 = 0.70710678118654752440f; } // namespace - void MLASCALL MlasGeluErfKernel( @@ -63,3 +62,37 @@ MlasComputeGeluErf( MlasGeluErfKernel(Input, Output, N); #endif } + +void +MLASCALL +MlasComputeFP16Gelu(const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + size_t count, + MLAS_GELU_ALGORITHM algo) +{ + if(GetMlasPlatform().GeluFP16KernelRoutine){ + GetMlasPlatform().GeluFP16KernelRoutine(input, output, temp, count, algo); + return; + } + MLAS_UNREFERENCED_PARAMETER(temp); // 'temp' is only used by vectorized kernel implementations and it is unused in the scalar fallback path. + for (size_t i = 0; i < count; ++i) { + float x = static_cast(input[i]); + float gelu_val; + + if (algo == MlasGeluTanh) { + // GELU approximation (tanh) + const float B = 0.7978845608f; + const float C = 0.044715f * B; + float tanh_arg = x * (B + C * x * x); + float tanh_res = std::tanh(tanh_arg); + gelu_val = 0.5f * x * (1.0f + tanh_res); + } else { + // GELU exact (erf) + gelu_val = 0.5f * x * + (1.0f + std::erf(x * static_cast(M_SQRT1_2))); + } + + output[i] = MLAS_FP16(gelu_val); + } +} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp new file mode 100644 index 0000000000000..0deb532318482 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu_neon_fp16.cpp @@ -0,0 +1,96 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu_neon_fp16.cpp + +Abstract: + + This module contains Gelu helper functions . 
+ +--*/ +#include "gelu_neon_fp16.h" +#include +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + +void +MLASCALL +MlasNeonGeluFP16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, size_t count, MLAS_GELU_ALGORITHM algo) +{ + const float16_t v_half1 = 0.5f; + const float16_t v_one1 = 1.0f; + const float16_t v_sqrt1_21 = static_cast(M_SQRT1_2); + const float16_t v_B1 = 0.7978845608028654f; + const float16_t v_C1 = 0.035677408136300125f; + const float16_t c1 = 5.0f; + const float16_t c2 = -5.0f; + const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1); + const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1); + const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21); + const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1); + const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1); + + size_t i = 0; + + if (algo == MlasGeluTanh) { + // Preprocess input into temp[] for tanh + for (; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x); + MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B); // B + C * x^2 + MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner); // x * (B + C * x^2) + tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1))); + MlasStoref16Float16x8(reinterpret_cast(temp + i), tanh_arg); + } + + // Tail + for (; i < count; ++i) { + float x = static_cast(input[i]); + float inner = x * (0.7978845608028654f + 0.035677408136300125f * x * x); + inner = std::max(-5.0f, std::min(5.0f, inner)); + temp[i] = static_cast(inner); + } + + // Tanh processing + MlasComputeTanh(temp, temp, count); + + } else{ + // Preprocess input into temp[] for erf + for (i = 0; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2); + 
MlasStoref16Float16x8(reinterpret_cast(temp + i), scaled); + } + + // Tail + for (; i < count; ++i) { + float x = static_cast(input[i]); + temp[i] = static_cast(x * static_cast(M_SQRT1_2)); + } + + // Erf processing + MlasNeonErfFP16Kernel(temp, temp, count); + } + + // Final GELU output = 0.5 * x * (1 + tanh|erf) + i = 0; + for (; i + 7 < count; i += 8) { + MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast(input + i)); + MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast(temp + i)); + MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t))); + MlasStoref16Float16x8(reinterpret_cast(output + i), result); + } + + for (; i < count; ++i) { + float x = static_cast(input[i]); + float t = static_cast(temp[i]); + float gelu = 0.5f * x * (1.0f + t); + output[i] = static_cast(gelu); + } +} +#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/gelu_neon_fp16.h b/onnxruntime/core/mlas/lib/gelu_neon_fp16.h new file mode 100644 index 0000000000000..faabf561132f8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu_neon_fp16.h @@ -0,0 +1,31 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu_neon_fp16.h + +Abstract: + + This module contains Gelu helper functions . 
+ +--*/ + +#pragma once + +#include "fp16_common.h" +#include "erf_neon_fp16.h" + +void +MLASCALL +MlasNeonGeluFP16Kernel( + const MLAS_FP16* input, + MLAS_FP16* output, + MLAS_FP16* temp, + size_t count, + MLAS_GELU_ALGORITHM algo +); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b374caa9664f8..662e757a47998 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,6 +610,31 @@ void size_t N ); +typedef +void +(MLASCALL MLAS_COMPUTE_ERF_FP16_KERNEL)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +); + +typedef +void +(MLASCALL MLAS_COMPUTE_GELU_FP16_KERNEL)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + MLAS_FP16* Temp, + size_t N, + MLAS_GELU_ALGORITHM Algo +); + +typedef void +(MLASCALL MLAS_COMPUTE_TANH_FP16_KERNEL)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +); + typedef float (MLASCALL MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL)( @@ -1119,6 +1144,7 @@ extern "C" { MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; + #if defined(MLAS_TARGET_AMD64) MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; @@ -1489,6 +1515,11 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif + +MLAS_COMPUTE_ERF_FP16_KERNEL* ErfFP16KernelRoutine = nullptr; +MLAS_COMPUTE_GELU_FP16_KERNEL* GeluFP16KernelRoutine = nullptr; +MLAS_COMPUTE_TANH_FP16_KERNEL* TanhFP16KernelRoutine = nullptr; + #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index e7d4bf12aa289..e9f140a2ee0f7 100644 
--- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -19,6 +19,10 @@ Module Name: #ifdef MLAS_USE_SVE #include "sve/mlasi_sve.h" #endif +#if defined(MLAS_NEON_INTRINSICS) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +#include "erf_neon_fp16.h" +#include "gelu_neon_fp16.h" +#endif #if defined(USE_KLEIDIAI) #include "kleidiai/mlasi_kleidiai.h" #endif @@ -658,6 +662,23 @@ Return Value: } #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && !defined(_WIN32) + #if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + this->ErfFP16KernelRoutine = MlasSveErfFP16Kernel; + this->GeluFP16KernelRoutine = MlasSveGeluFP16Kernel; + this->TanhFP16KernelRoutine = MlasSveTanhFP16Kernel; + } + else{ + this->ErfFP16KernelRoutine = MlasNeonErfFP16Kernel; + this->GeluFP16KernelRoutine = MlasNeonGeluFP16Kernel; + } + #else + this->ErfFP16KernelRoutine = MlasNeonErfFP16Kernel; + this->GeluFP16KernelRoutine = MlasNeonGeluFP16Kernel; + #endif +#endif + // // Check if the processor supports ASIMD I8MM instructions. // diff --git a/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp new file mode 100644 index 0000000000000..26f2a09439910 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sve/elementwise_sve_fp16.cpp @@ -0,0 +1,256 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + Elementwise_sve_fp16.cpp + +Abstract: + + This module contains the SVE Elementwise functions . 
+ +--*/ +#include "mlas_sve_fp16.h" + +using _mlas_fp16_ = uint16_t; +struct MlasTanhConstants_fp16_scalar { + __fp16 LowerRange; + __fp16 UpperRange; + __fp16 alpha_7; + __fp16 alpha_5; + __fp16 alpha_3; + __fp16 alpha_1; + __fp16 beta_6; + __fp16 beta_4; + __fp16 beta_2; + __fp16 beta_0; +}; + +constexpr MlasTanhConstants_fp16_scalar TanhConstantsFp16 = { + -3.515625f, + 3.515625f, + 5.960464477539063e-08f, + 1.4841556549072266e-05f, + 0.000637054443359375f, + 0.004894256591796875f, + 1.1920928955078125e-06f, + 0.00011855363845825195f, + 0.0022678375244140625f, + 0.004894256591796875f +}; + +static inline MLAS_SVFLOAT16 +Tanh_Vector_SVE_fp16(MLAS_SVFLOAT16 x, MLAS_SVBOOL pg) +{ + MLAS_SVFLOAT16 g_LowerRange_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.LowerRange); + MLAS_SVFLOAT16 g_UpperRange_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.UpperRange); + MLAS_SVFLOAT16 g_alpha_7_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_7); + MLAS_SVFLOAT16 g_alpha_5_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_5); + MLAS_SVFLOAT16 g_alpha_3_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_3); + MLAS_SVFLOAT16 g_alpha_1_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.alpha_1); + MLAS_SVFLOAT16 g_beta_6_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_6); + MLAS_SVFLOAT16 g_beta_4_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_4); + MLAS_SVFLOAT16 g_beta_2_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_2); + MLAS_SVFLOAT16 g_beta_0_vec = MlasSveBroadcastfloat16(TanhConstantsFp16.beta_0); + + x = MlasSveMinfloat16(pg, x, g_UpperRange_vec); + x = MlasSveMaxfloat16(pg, x, g_LowerRange_vec); + + MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, x, x); + MLAS_SVFLOAT16 p = MlasSveMLAfloat16(pg, g_alpha_5_vec, g_alpha_7_vec, x2); + p = MlasSveMLAfloat16(pg, g_alpha_3_vec, p, x2); + p = MlasSveMLAfloat16(pg, g_alpha_1_vec, p, x2); + p = MlasSveMulfloat16(pg, p, x); + + svfloat16_t q = MlasSveMLAfloat16(pg, g_beta_4_vec, g_beta_6_vec, x2); + q = 
MlasSveMLAfloat16(pg, g_beta_2_vec, q, x2); + q = MlasSveMLAfloat16(pg, g_beta_0_vec, q, x2); + + MLAS_SVFLOAT16 res = MlasSveDivfloat16(pg, p, q); + + return res; +} + +void +MlasSveTanhFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) +{ + size_t offset = 0; + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + while (offset < N) { + MLAS_SVBOOL pg = MlasSveSelPredicatefloat16(offset, N); + MLAS_SVFLOAT16 x = MlasSvereinterpretf16_u16(MlasSveLoadUint16(pg, &input[offset])); + MLAS_SVFLOAT16 y = Tanh_Vector_SVE_fp16(x, pg); + MlasSveStoreUint16(pg, &output[offset], MlasSvereinterpretu16_f16(y)); + offset += svcnth(); + } +} + +static inline MLAS_SVFLOAT16 +exp_neg_rational_approx_f16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x) +{ + const __fp16 a0 = 6.0f; + MLAS_SVFLOAT16 max_x = MlasSveBroadcastfloat16(a0); + x = MlasSveMinfloat16(pg, x, max_x); + + const __fp16 c0 = 1.330f; + const __fp16 c1 = -0.390f; + const __fp16 c2 = 0.0288f; + + const __fp16 d0 = 1.338f; + const __fp16 d1 = 0.848f; + const __fp16 d2 = 0.467f; + + MLAS_SVFLOAT16 c0v = MlasSveBroadcastfloat16(c0); + MLAS_SVFLOAT16 c1v = MlasSveBroadcastfloat16(c1); + MLAS_SVFLOAT16 c2v = MlasSveBroadcastfloat16(c2); + MLAS_SVFLOAT16 d0v = MlasSveBroadcastfloat16(d0); + MLAS_SVFLOAT16 d1v = MlasSveBroadcastfloat16(d1); + MLAS_SVFLOAT16 d2v = MlasSveBroadcastfloat16(d2); + MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, x, x); + + MLAS_SVFLOAT16 num = MlasSveMLAfloat16(pg, c0v, c1v, x); + num = MlasSveMLAfloat16(pg, num, c2v, x2); + + MLAS_SVFLOAT16 den = MlasSveMLAfloat16(pg, d0v, d1v, x); + den = MlasSveMLAfloat16(pg, den, d2v, x2); + + MLAS_SVFLOAT16 recip = MlasSveReciprocalfloat16(den); + recip = MlasSveMulfloat16(pg, recip, MlasSveReciprocalStepfloat16(den, recip)); + recip = MlasSveMulfloat16(pg, recip, MlasSveReciprocalStepfloat16(den, recip)); + + MLAS_SVFLOAT16 result = MlasSveMulfloat16(pg, num, recip); + return result; +} + +void MLASCALL 
+MlasSveErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) +{ + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + const __fp16 p = 0.328f; + const __fp16 a1 = 0.2505f; + const __fp16 a2 = -0.2881f; + const __fp16 a3 = 1.4102f; + const __fp16 a4 = -1.423f; + const __fp16 a5 = 1.0547f; + + MLAS_SVFLOAT16 vp = MlasSveBroadcastfloat16(p); + MLAS_SVFLOAT16 va1 = MlasSveBroadcastfloat16(a1); + MLAS_SVFLOAT16 va2 = MlasSveBroadcastfloat16(a2); + MLAS_SVFLOAT16 va3 = MlasSveBroadcastfloat16(a3); + MLAS_SVFLOAT16 va4 = MlasSveBroadcastfloat16(a4); + MLAS_SVFLOAT16 va5 = MlasSveBroadcastfloat16(a5); + + const __fp16 v1 = 1.0f; + const __fp16 v2 = -1.0f; + const __fp16 v3 = 0.0f; + const __fp16 v4 = 4.0f; + MLAS_SVFLOAT16 vone = MlasSveBroadcastfloat16(v1); + MLAS_SVFLOAT16 vneg_one = MlasSveBroadcastfloat16(v2); + MLAS_SVFLOAT16 vzero = MlasSveBroadcastfloat16(v3); + MLAS_SVFLOAT16 vth = MlasSveBroadcastfloat16(v4); + + size_t i = 0; + while (i < N) { + MLAS_SVBOOL pg = MlasSveSelPredicatefloat16(i, N); + MLAS_SVFLOAT16 x = MlasSvereinterpretf16_u16(MlasSveLoadUint16(pg, &input[i])); + MLAS_SVBOOL neg_mask = MlasSveComparelessthanfloat16(pg, x, vzero); + MLAS_SVFLOAT16 sign = MlasSveSelectfloat16(neg_mask, vneg_one, vone); + MLAS_SVFLOAT16 absx = MlasSveAbsolutefloat16(MlasSveBroadcastfloat16(v3), pg, x); + svbool_t use_mask = MlasSveComparelessthanfloat16(pg, absx, vth); + MLAS_SVFLOAT16 absx_clamped = MlasSveMinfloat16(pg, absx, vth); + MLAS_SVFLOAT16 denom = MlasSveMLAfloat16(pg, vone, vp, absx_clamped); + MLAS_SVFLOAT16 t = MlasSveReciprocalfloat16(denom); + t = MlasSveMulfloat16(pg, t, MlasSveReciprocalStepfloat16(denom, t)); + t = MlasSveMulfloat16(pg, t, MlasSveReciprocalStepfloat16(denom, t)); + MLAS_SVFLOAT16 t2 = MlasSveMulfloat16(pg, t, t); + MLAS_SVFLOAT16 t3 = MlasSveMulfloat16(pg, t2, t); + MLAS_SVFLOAT16 t4 = MlasSveMulfloat16(pg, t3, t); + MLAS_SVFLOAT16 t5 = MlasSveMulfloat16(pg, t4, t); + 
svfloat16_t poly = MlasSveMulfloat16(pg, va1, t); + poly = MlasSveMLAfloat16(pg, poly, va2, t2); + poly = MlasSveMLAfloat16(pg, poly, va3, t3); + poly = MlasSveMLAfloat16(pg, poly, va4, t4); + poly = MlasSveMLAfloat16(pg, poly, va5, t5); + MLAS_SVFLOAT16 x2 = MlasSveMulfloat16(pg, absx_clamped, absx_clamped); + MLAS_SVFLOAT16 exp_neg_x2 = exp_neg_rational_approx_f16(pg, x2); + MLAS_SVFLOAT16 poly_mul_exp = MlasSveMulfloat16(pg, poly, exp_neg_x2); + MLAS_SVFLOAT16 one_minus_term = MlasSveSubtractfloat16(pg, vone, poly_mul_exp); + MLAS_SVFLOAT16 erf_approx = MlasSveMulfloat16(pg, sign, one_minus_term); + erf_approx = MlasSveMinfloat16(pg, erf_approx, vone); + erf_approx = MlasSveMaxfloat16(pg, erf_approx, vneg_one); + MLAS_SVFLOAT16 result = MlasSveSelectfloat16(use_mask, erf_approx, sign); + MlasSveStoreUint16(pg, &output[i], MlasSvereinterpretu16_f16(result)); + i += svcntp_b16(svptrue_b16(), pg); + } +} + +void MLASCALL +MlasSveGeluFP16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, size_t count, MLAS_GELU_ALGORITHM algo) +{ + const __fp16 r1 = 0.5f; + const __fp16 r2 = 1.0f; + const __fp16 r3 = static_cast(M_SQRT1_2); + const __fp16 r4 = 0.7978845608028654f; + const __fp16 r5 = 0.035677408136300125f; + + const MLAS_SVFLOAT16 v_half = MlasSveBroadcastfloat16(r1); + const MLAS_SVFLOAT16 v_one = MlasSveBroadcastfloat16(r2); + const MLAS_SVFLOAT16 v_sqrt1_2 = MlasSveBroadcastfloat16(r3); + const MLAS_SVFLOAT16 v_B = MlasSveBroadcastfloat16(r4); + const MLAS_SVFLOAT16 v_C = MlasSveBroadcastfloat16(r5); + + const __fp16 c1 = -5.0f; + const __fp16 c2 = 5.0f; + if (algo == MlasGeluTanh) { + size_t i = 0; + while (i < count) { + svbool_t pg = MlasSveSelPredicatefloat16(i, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]); + MLAS_SVFLOAT16 v_x2 = MlasSveMulfloat16(pg, v_x, v_x); + MLAS_SVFLOAT16 v_inner = MlasSveMLAfloat16(pg, v_B, v_C, v_x2); + MLAS_SVFLOAT16 v_tanh_arg = MlasSveMulfloat16(pg, v_x, v_inner); + v_tanh_arg = 
MlasSveMaxfloat16(pg, MlasSveBroadcastfloat16(c1), MlasSveMinfloat16(pg, v_tanh_arg, MlasSveBroadcastfloat16(c2))); + MlasSveStoreF16(pg, &temp[i], v_tanh_arg); + i += svcnth(); + } + + MlasSveTanhFP16Kernel(reinterpret_cast(temp), reinterpret_cast(temp), count); + + size_t j = 0; + while (j < (count)) { + svbool_t pg = MlasSveSelPredicatefloat16(j, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[j]); + MLAS_SVFLOAT16 v_tanh = MlasSveLoadFloat16(pg, &temp[j]); + MLAS_SVFLOAT16 v_result = MlasSveMulfloat16(pg, v_half, MlasSveMulfloat16(pg, v_x, svadd_f16_m(pg, v_one, v_tanh))); + MlasSveStoreF16(pg, &output[j], v_result); + j += svcnth(); + } + } else { + size_t i = 0; + while (i < (count)) { + svbool_t pg = MlasSveSelPredicatefloat16(i, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[i]); + MLAS_SVFLOAT16 v_scaled = MlasSveMulfloat16(pg, v_x, v_sqrt1_2); + MlasSveStoreF16(pg, &temp[i], v_scaled); + i += svcnth(); + } + + MlasSveErfFP16Kernel(temp, temp, count); + + size_t j = 0; + while (j < (count)) { + svbool_t pg = MlasSveSelPredicatefloat16(j, count); + MLAS_SVFLOAT16 v_x = MlasSveLoadFloat16(pg, &input[j]); + MLAS_SVFLOAT16 v_erf = MlasSveLoadFloat16(pg, &temp[j]); + MLAS_SVFLOAT16 v_result = MlasSveMulfloat16(pg, v_half, MlasSveMulfloat16(pg, v_x, MlasSveAddfloat16(pg, v_one, v_erf))); + MlasSveStoreF16(pg, &output[j], v_result); + j += svcnth(); + } + } +} diff --git a/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h b/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h new file mode 100644 index 0000000000000..cde88c5517873 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sve/mlas_sve_fp16.h @@ -0,0 +1,181 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_sve_fp16.h + +Abstract: + + This module contains the procedure prototypes for the SVE FP16 intrinsics. 
+ +--*/ + +#pragma once +#include "mlasi_sve.h" + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveBroadcastfloat16(__fp16 Value) +{ + return svdup_f16(Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMinfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 range) +{ + return svmin_f16_m(pg, x, range); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMaxfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 range) +{ + return svmax_f16_m(pg, x, range); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMulfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svmul_f16_m(pg, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveMLAfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y, MLAS_SVFLOAT16 z) +{ + return svmla_f16_m(pg, x, y, z); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveDivfloat16(MLAS_SVBOOL pg, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svdiv_f16_m(pg, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveSelPredicatefloat16(size_t x, size_t y) +{ + return svwhilelt_b16(x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSvereinterpretf16_u16(MLAS_SVUINT16 x) +{ + return svreinterpret_f16_u16(x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT16 +MlasSveLoadUint16(MLAS_SVBOOL pg, const uint16_t* x) +{ + return svld1_u16(pg, x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveLoadFloat16(MLAS_SVBOOL pg, const MLAS_FP16* x) +{ + return svld1_f16(pg, reinterpret_cast<const __fp16*>(x)); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreUint16(MLAS_SVBOOL pg, uint16_t* Buffer, MLAS_SVUINT16 Vector) +{ + return svst1_u16(pg, Buffer, Vector); +} +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreF16(MLAS_SVBOOL pg, MLAS_FP16* Buffer, MLAS_SVFLOAT16 Vector) +{ + return svst1_f16(pg,
reinterpret_cast<__fp16*>(Buffer), Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT16 +MlasSvereinterpretu16_f16(MLAS_SVFLOAT16 x) +{ + return svreinterpret_u16_f16(x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveReciprocalfloat16(MLAS_SVFLOAT16 x) +{ + return svrecpe_f16(x); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveReciprocalStepfloat16(MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svrecps_f16(x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveSelectfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svsel_f16(Pred, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveSubtractfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svsub_f16_m(Pred, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveComparelessthanfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svcmplt_f16(Pred, x, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveAbsolutefloat16(MLAS_SVFLOAT16 inactive, MLAS_SVBOOL Pred, MLAS_SVFLOAT16 y) +{ + return svabs_f16_m(inactive, Pred, y); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT16 +MlasSveAddfloat16(MLAS_SVBOOL Pred, MLAS_SVFLOAT16 x, MLAS_SVFLOAT16 y) +{ + return svadd_f16_m(Pred, x, y); +} +#endif diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h index 67a4bf453dd05..922945c702119 100644 --- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h +++ b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h @@ -32,8 +32,35 @@ typedef svfloat32_t MLAS_SVFLOAT32; typedef svint32_t MLAS_SVINT32; typedef svuint32_t MLAS_SVUINT32; typedef svbool_t MLAS_SVBOOL; +typedef svfloat16_t MLAS_SVFLOAT16; +typedef svuint16_t MLAS_SVUINT16; -// function decarations +void +MLASCALL +MlasSveErfFP16Kernel( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +); + +void +MLASCALL +MlasSveTanhFP16Kernel( + const MLAS_FP16* Input, + 
MLAS_FP16* Output, + size_t N +); + +void +MLASCALL +MlasSveGeluFP16Kernel( + const MLAS_FP16* Input, + MLAS_FP16* Output, + MLAS_FP16* Temp, + size_t N, + MLAS_GELU_ALGORITHM Algo +); +// function declarations MLAS_FORCEINLINE MLAS_SVFLOAT32 MlasSveComputeExpVector( diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index 63bd744795535..adb1d7170bd5c 100644 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -193,6 +193,10 @@ MlasComputeTanh( MLAS_FP16* Output, size_t N ) { + if(GetMlasPlatform().TanhFP16KernelRoutine){ + GetMlasPlatform().TanhFP16KernelRoutine(Input, Output, N); + return; + } const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; if (dispatch == nullptr || dispatch->Tanh_Fp16 == nullptr) { MLAS_THROW_EX(std::runtime_error, "Tanh_Fp16 is not supported."); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 9ded95753cd8e..fb854b19accfa 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1251,7 +1251,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, Gelu); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); +#endif #if !defined(DISABLE_FLOAT8_TYPES) class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); @@ -3337,7 +3340,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { IsNaN)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, @@ -3752,7 +3756,8 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - }; + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index cadb270d0c17e..4ddb5c7e78037 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -2023,7 +2023,7 @@ Status Erf::Compute(OpKernelContext* context) const { const float* p_input = input_data + start; float* p_output = output_data + start; const std::ptrdiff_t count = std::min(length_per_task, elem_count - start); - MlasComputeErf(p_input, p_output, count); + MlasComputeErf(p_input, p_output, static_cast(count)); }, 0); @@ -2041,7 +2041,6 @@ Status Erf::Compute(OpKernelContext* context) const { int64_t elem_count = X->Shape().Size(); constexpr int64_t length_per_task = 4096; int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; - const auto narrow_task_count = onnxruntime::narrow(task_count); // get allocator for temporary buffers @@ -2053,19 +2052,12 @@ Status Erf::Compute(OpKernelContext* context) const { [&](ptrdiff_t task_idx) { const auto start = task_idx * length_per_task; const int64_t count = std::min(length_per_task, elem_count - start); - const auto narrow_count = onnxruntime::narrow(count); - const MLFloat16* p_input 
= input_data + start; MLFloat16* p_output = output_data + start; - - // allocate temp buffers using ORT allocator - IAllocatorUniquePtr<float> input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, narrow_count); - IAllocatorUniquePtr<float> output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, narrow_count); - - // convert, compute, convert back - MlasConvertHalfToFloatBuffer(p_input, input_fp32.get(), narrow_count); - MlasComputeErf(input_fp32.get(), output_fp32.get(), narrow_count); - MlasConvertFloatToHalfBuffer(output_fp32.get(), p_output, narrow_count); + // allocate temp buffers for fp32 input and output + IAllocatorUniquePtr<float> input_tmp_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(count)); + IAllocatorUniquePtr<float> output_tmp_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(count)); + MlasComputeFP16Erf(p_input, p_output, input_tmp_fp32.get(), output_tmp_fp32.get(), static_cast<size_t>(count)); }, 0); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index e34af83d1f29e..2985469f35bf9 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -12,6 +12,9 @@ #include "core/providers/cpu/element_wise_ranged_transform.h" #include "core/providers/cpu/tensor/gelu.h" +#include <algorithm> +#include <memory> + using onnxruntime::narrow; using namespace onnxruntime::common; @@ -19,11 +22,17 @@ namespace onnxruntime { // May revisit the implementations to support inplace computation, if needed.
-ONNX_CPU_OPERATOR_KERNEL( - Gelu, - 20, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), - Gelu<float>); +#define ADD_TYPED_GELU_OP(data_type) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + Gelu, \ + 20, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \ + Gelu<data_type>) + +ADD_TYPED_GELU_OP(float); +ADD_TYPED_GELU_OP(MLFloat16); #ifndef DISABLE_CONTRIB_OPS namespace contrib { @@ -98,4 +107,75 @@ Status Gelu::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); } -} // namespace onnxruntime +template <> +Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input<Tensor>(0); + const MLFloat16* input_data = input->Data<MLFloat16>(); + Tensor* output = context->Output(0, input->Shape()); + MLFloat16* output_data = output->MutableData<MLFloat16>(); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + + int64_t elem_count = input->Shape().Size(); + constexpr int64_t length_per_task = 4096; + int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + + MLAS_GELU_ALGORITHM algo; + if (approximation_algorithm_ == "tanh") { + algo = MlasGeluTanh; + } else if (approximation_algorithm_ == "none") { + algo = MlasGeluErf; + } else { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported approximation_algorithm: ", + approximation_algorithm_); + } + + if (elem_count == 0) { + return Status::OK(); + } + + // Allocate scratch buffer using ORT temp-space allocator + size_t buffer_size = static_cast<size_t>(elem_count) * sizeof(MLFloat16); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + void* raw = allocator->Alloc(buffer_size); + if (!raw) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to allocate temporary buffer."); + } + + auto deleter = [allocator](MLFloat16* p) { + if (p) allocator->Free(p); + }; + + 
std::unique_ptr<MLFloat16, decltype(deleter)> temp_fp16( + static_cast<MLFloat16*>(raw), deleter); + + concurrency::ThreadPool::TryBatchParallelFor( + tp, + static_cast<ptrdiff_t>(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const MLFloat16* p_input = input_data + start; + MLFloat16* p_output = output_data + start; + + int64_t count = std::min(length_per_task, elem_count - start); + + MLFloat16* p_temp = temp_fp16.get() + start; + + MlasComputeFP16Gelu( + p_input, + p_output, + p_temp, + narrow<size_t>(count), + algo); + }, + 0); + + return Status::OK(); +} + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index d711e050fb913..b0a4ce3ed6599 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -752,6 +752,58 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { {}, {{"approximate", "tanh"}}, true, 20); } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +TEST_F(ActivationOpTest, Gelu_fp16_tanh) { + OpTester test("Gelu", 20); + auto formula = [](float x) { + return 0.5f * x * (1 + tanhf(0.7978845608028654f * (x + 0.044715f * x * x * x))); + }; + const std::vector<float> X = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + std::vector<float> Y; + Y.reserve(X.size()); + for (float x : X) { + Y.push_back(formula(x)); + } + std::vector<int64_t> dims{static_cast<int64_t>(X.size())}; + + std::vector<MLFloat16> f_X(X.size()); + std::vector<MLFloat16> f_Y(Y.size()); + ConvertFloatToMLFloat16(X.data(), f_X.data(), static_cast<size_t>(X.size())); + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), static_cast<size_t>(Y.size())); + + test.AddInput<MLFloat16>("X", dims, f_X); + test.AddOutput<MLFloat16>("Y", dims, f_Y); + test.AddAttribute("approximate", "tanh"); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST_F(ActivationOpTest, Gelu_fp16_erf) { + OpTester test("Gelu", 20); + auto formula = [](float
x) { + return static_cast<float>(0.5 * x * (1 + erf(x * M_SQRT1_2))); + }; + const std::vector<float> X = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + std::vector<float> Y; + Y.reserve(X.size()); + for (float x : X) { + Y.push_back(formula(x)); + } + std::vector<int64_t> dims{static_cast<int64_t>(X.size())}; + + std::vector<MLFloat16> f_X(X.size()); + std::vector<MLFloat16> f_Y(Y.size()); + ConvertFloatToMLFloat16(X.data(), f_X.data(), static_cast<size_t>(X.size())); + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), static_cast<size_t>(Y.size())); + + test.AddInput<MLFloat16>("X", dims, f_X); + test.AddOutput<MLFloat16>("Y", dims, f_Y); + test.AddAttribute("approximate", "none"); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} +#endif #endif } // namespace test