Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,10 @@ else()
if (onnxruntime_USE_SVE)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlas_sve_fp16.h)
set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
endif()

Expand Down Expand Up @@ -548,6 +551,10 @@ else()
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/erf_neon_fp16.h
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
${MLAS_SRC_DIR}/gelu_neon_fp16.h
${MLAS_SRC_DIR}/gelu_neon_fp16.cpp
)
if (onnxruntime_USE_ARM_NEON_NCHWC)
list(APPEND mlas_platform_srcs
Expand Down Expand Up @@ -577,6 +584,8 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down Expand Up @@ -988,4 +997,4 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
endif()
endif()

endif()
endif()
69 changes: 69 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -2237,3 +2237,72 @@ MlasFlashAttention(
MlasFlashAttentionThreadedArgs* args,
MLAS_THREADPOOL* ThreadPool
);

/**
* @brief Enumeration of supported GELU algorithm variants.
*
* MlasGeluErf - Exact GELU implementation using the error function (erf).
* MlasGeluTanh - Approximate GELU implementation using tanh-based formulation.
*/
typedef enum MLAS_GELU_ALGORITHM {
MlasGeluErf = 0,   // exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
MlasGeluTanh = 1   // tanh-based approximation (see comment block above)
} MLAS_GELU_ALGORITHM;

/**
* @brief Computes element-wise FP16 error function (erf).
*
* This routine computes:
* Output[i] = erf(Input[i])
* for N elements. Depending on platform capabilities, this may use
* vectorized FP16 intrinsics or fall back to a scalar FP32 conversion path.
*
* @param Input Pointer to input buffer of N FP16 elements.
* @param Output Pointer to output buffer of N FP16 elements.
* @param Input_tmp_fp32 Pointer to caller-allocated scratch buffer of N floats
* for FP32 input conversion (used only on fallback path).
* @param Output_tmp_fp32 Pointer to caller-allocated scratch buffer of N floats
* for FP32 output conversion (used only on fallback path).
* @param N Number of elements to process.
*/
void
Comment thread
hariharans29 marked this conversation as resolved.
MLASCALL
MlasComputeFP16Erf(
const MLAS_FP16* Input,
MLAS_FP16* Output,
float* Input_tmp_fp32,
float* Output_tmp_fp32,
size_t N
);

/**
* @brief Computes element-wise FP16 GELU activation.
*
* This routine computes:
*
* If algo == MlasGeluTanh (approximate):
* GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
*
* If algo == MlasGeluErf (exact):
* GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
*
* Depending on platform capabilities, this may use vectorized FP16 kernels
* (SVE/NEON) or fall back to a scalar FP32 conversion path.
*
* @param input Pointer to input buffer of FP16 elements.
* @param output Pointer to output buffer of FP16 elements.
* @param temp Temporary scratch buffer of at least 'count' FP16 elements.
* Required by certain vectorized implementations. May be unused
* in scalar fallback paths.
* @param count Number of elements to process.
* @param algo GELU algorithm variant (exact erf or tanh approximation).
*/
// Declaration only -- the full contract is documented in the comment block above.
// NOTE(review): 'temp' must hold at least 'count' FP16 elements per that block;
// whether the scalar fallback touches it is not visible here -- confirm in the
// implementation before passing a null/short buffer.
void
MLASCALL
MlasComputeFP16Gelu(
const MLAS_FP16* input,
MLAS_FP16* output,
MLAS_FP16* temp,
size_t count,
MLAS_GELU_ALGORITHM algo
);
20 changes: 20 additions & 0 deletions onnxruntime/core/mlas/lib/erf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,23 @@ Return Value:
MlasErfKernel(Input, Output, N);
#endif
}

void
MLASCALL
MlasComputeFP16Erf(
    const MLAS_FP16* Input,
    MLAS_FP16* Output,
    float* Input_tmp_fp32,
    float* Output_tmp_fp32,
    size_t N
    )
/*++

Routine Description:

    Computes Output[i] = erf(Input[i]) for N half-precision elements.
    Dispatches to a platform-specific FP16 kernel when one is registered;
    otherwise widens to FP32, evaluates with MlasComputeErf, and narrows back.

Arguments:

    Input - Supplies the FP16 input buffer of N elements.

    Output - Supplies the FP16 output buffer of N elements.

    Input_tmp_fp32 - Supplies caller-allocated FP32 scratch of N elements
        (used only on the fallback path).

    Output_tmp_fp32 - Supplies caller-allocated FP32 scratch of N elements
        (used only on the fallback path).

    N - Supplies the element count.

--*/
{
    const auto FP16Kernel = GetMlasPlatform().ErfFP16KernelRoutine;

    if (FP16Kernel != nullptr) {
        // Native FP16 path: the FP32 scratch buffers are never touched.
        FP16Kernel(Input, Output, N);
    } else {
        // Fallback path: FP16 -> FP32, FP32 erf, FP32 -> FP16.
        MlasConvertHalfToFloatBuffer(Input, Input_tmp_fp32, N);
        MlasComputeErf(Input_tmp_fp32, Output_tmp_fp32, N);
        MlasConvertFloatToHalfBuffer(Output_tmp_fp32, Output, N);
    }
}
156 changes: 156 additions & 0 deletions onnxruntime/core/mlas/lib/erf_neon_fp16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*++

Copyright 2025 FUJITSU LIMITED
Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

erf_neon_fp16.cpp

Abstract:

This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#include "erf_neon_fp16.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)

using _mlas_fp16_ = uint16_t;
// Helpers to safely convert between float and FP16-bit representation
// Reinterprets an FP16 value stored as raw bits (uint16_t) as __fp16 and
// widens it to float. memcpy-based type punning avoids strict-aliasing UB.
// The stray review-thread text that was interleaved in this function body
// has been removed -- it is not valid C++ and broke compilation.
static float
fp16_to_float(uint16_t h)
{
    __fp16 tmp;
    std::memcpy(&tmp, &h, sizeof(h));
    return (float)tmp;
}

// Narrows a float to FP16 and returns the raw bit pattern. The (__fp16) cast
// performs the rounding narrowing conversion; memcpy extracts the bits
// without strict-aliasing UB.
static uint16_t
float_to_fp16(float f)
{
__fp16 tmp = (__fp16)f;
uint16_t h;
std::memcpy(&h, &tmp, sizeof(h));
return h;
}

// Approximates exp(-x) on 8 FP16 lanes with a degree-2 / degree-2 rational
// function evaluated entirely in half precision:
//     exp(-x) ~= (c0 + c1*x + c2*x^2) / (d0 + d1*x + d2*x^2)
// The argument is clamped to [.., 6]; exp(-6) ~= 0.0025, already near FP16
// noise for the erf tail this feeds. Coefficients look hand-tuned for FP16 --
// TODO confirm the fit range / error bound against the reference.
static inline MLAS_FLOAT16X8
exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
{
// Clamp so the rational fit is only evaluated on its tuned range.
const float16_t a0 = 6.0f;
MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0);
x = MlasMinimumFloat16(x, max_x);

// Numerator coefficients.
const float16_t c0 = 1.330f;
const float16_t c1 = -0.390f;
const float16_t c2 = 0.0288f;

// Denominator coefficients.
const float16_t d0 = 1.338f;
const float16_t d1 = 0.848f;
const float16_t d2 = 0.467f;

MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0);
MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1);
MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2);

MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0);
MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1);
MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2);
MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
// Evaluate numerator and denominator with fused multiply-adds.
MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v);
num = MlasMultiplyAddFloat16(c2v, x2, num);
MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
den = MlasMultiplyAddFloat16(d2v, x2, den);
// Reciprocal of the denominator: hardware estimate refined by two
// reciprocal-step iterations (vrecpe/vrecps pattern).
MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
return result;
}

// Vectorized FP16 erf kernel (NEON, 8 lanes per iteration). Uses a
// polynomial-in-t approximation of the same shape as Abramowitz & Stegun
// 7.1.26:
//     erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
//     t = 1 / (1 + p*|x|)
// The coefficients here differ from the classic FP32 constants -- presumably
// refit for half precision; TODO confirm the accuracy target. For |x| >= 4
// the result saturates to sign(x). A scalar FP32 tail loop handles the
// remaining (N % 8) elements with the same formula.
void
MlasNeonErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N)
{
const auto* input = reinterpret_cast<const _mlas_fp16_*>(Input);
auto* output = reinterpret_cast<_mlas_fp16_*>(Output);
// Approximation constants (shared by the vector and scalar paths).
const float16_t p = 0.328f;
const float16_t a1 = 0.2505f;
const float16_t a2 = -0.2881f;
const float16_t a3 = 1.4102f;
const float16_t a4 = -1.423f;
const float16_t a5 = 1.0547f;

MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);

constexpr float16_t one_fp16 = 1.0f;
constexpr float16_t neg_one_fp16 = -1.0f;
constexpr float16_t zero_fp16 = 0.0f;
constexpr float16_t four_fp16 = 4.0f;

MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
// Saturation threshold: beyond |x| = 4, erf(x) is +/-1 at FP16 precision.
MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);

size_t i = 0;
for (; i + 8 <= N; i += 8) {
MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&input[i]);
// sign = -1 for negative lanes, +1 otherwise; erf is odd, so the
// approximation is evaluated on |x| and the sign reapplied at the end.
MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
// Lanes with |x| >= 4 bypass the approximation and take sign directly.
MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
// t = 1 / (1 + p*|x|), via reciprocal estimate + two refinement steps.
MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
// Degree-5 polynomial in t.
MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
poly = MlasMultiplyAddFloat16(va2, t2, poly);
poly = MlasMultiplyAddFloat16(va3, t3, poly);
poly = MlasMultiplyAddFloat16(va4, t4, poly);
poly = MlasMultiplyAddFloat16(va5, t5, poly);
// exp(-x^2) via the FP16 rational approximation above.
MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
// Clamp to the mathematical range of erf, [-1, 1].
erf_approx = MlasMinimumFloat16(erf_approx, vone);
erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
MlasStoreFloat16x8(&output[i], result);
}

// Scalar FP32 tail for the remaining (N % 8) elements; same formula as
// the vector path but with a true expf for exp(-x^2).
for (; i < N; i++) {
float x = fp16_to_float(input[i]);
float sign = (x < 0) ? -1.0f : 1.0f;
float absx = fabsf(x);

if (absx > 4.0f) {
output[i] = float_to_fp16(sign);
continue;
}

float t = 1.0f / (1.0f + p * absx);
float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
float exp_neg_x2 = expf(-absx * absx);
float erf_approx = sign * (1.0f - poly * exp_neg_x2);
if (erf_approx > 1.0f) erf_approx = 1.0f;
if (erf_approx < -1.0f) erf_approx = -1.0f;

output[i] = float_to_fp16(erf_approx);
}
}
#endif
27 changes: 27 additions & 0 deletions onnxruntime/core/mlas/lib/erf_neon_fp16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*++

Copyright 2025 FUJITSU LIMITED
Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

erf_neon_fp16.h

Abstract:

This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#pragma once

#include <arm_neon.h>

#include "mlasi.h"
#include "fp16_common.h"
#include "softmax_kernel_neon.h"
#include <cstring>

void MlasNeonErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N);
56 changes: 54 additions & 2 deletions onnxruntime/core/mlas/lib/fp16_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Module Name:
#include "mlas_float16.h"
#include "mlasi.h"

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)

// TODO!! Add intel fp16 implementations

Expand Down Expand Up @@ -579,4 +579,56 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector)
return vshl_n_s16(Vector, ShiftCount);
}

#endif // fp16 vector intrinsic supported
// NEON FP16 vector intrinsics
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
// NEON FP16 vector intrinsic wrappers.
// Broadcasts a single half-precision value to all 8 lanes (vdupq_n_f16).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBroadcastF16Float16x8(float16_t Value) { return vdupq_n_f16(Value); }

// Loads 8 contiguous half-precision values from Buffer (vld1q_f16).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasLoadf16Float16x8(const float16_t* Buffer) { return vld1q_f16(Buffer); }

// Stores 8 half-precision lanes to Buffer (vst1q_f16).
MLAS_FORCEINLINE
void
MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector)
{
vst1q_f16(Buffer, Vector);
}

// Reciprocal refinement step (vrecpsq_f16): returns (2 - Vector1 * Vector2),
// multiplied into the current estimate to refine a vrecpe-style reciprocal.
// Stray review-thread text interleaved here has been removed (compile fix).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasReciprocalStepFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
    return vrecpsq_f16(Vector1, Vector2);
}

// Per-lane reciprocal estimate (vrecpeq_f16); refine with
// MlasReciprocalStepFloat16 for usable precision.
// Stray review-thread text interleaved here has been removed (compile fix).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector)
{
    return vrecpeq_f16(Vector);
}

// Lane-wise compare (vcltq_f16): each result lane is all-ones where
// Vector1 < Vector2, else all-zeros.
// Stray review-thread text interleaved here has been removed (compile fix).
MLAS_FORCEINLINE
MLAS_UINT16X8
MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
    return vcltq_f16(Vector1, Vector2);
}

// Per-lane absolute value (vabsq_f16).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasAbsFloat16(MLAS_FLOAT16X8 Vector)
{
return vabsq_f16(Vector);
}

// Bitwise select (vbslq_f16): for each bit set in the mask Vector, take the
// corresponding bit from Vector1, otherwise from Vector2. With all-ones /
// all-zeros compare masks this acts as a per-lane select.
// Stray review-thread text interleaved here has been removed (compile fix).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
    return vbslq_f16(Vector, Vector1, Vector2);
}
#endif // NEON FP16 vector intrinsics supported
#endif // mlas fp16 intrinsic supported
Loading
Loading