[MLAS] Enable FP16 for Gelu #26815
Merged
Commits (23):
d03e5ca  Enable Gelu Fp16
d2ac54d  Resolve Review Comments
aa05b43  Resolved Copilot comments
f2b8a98  Resolved CI failures
c3c00f1  Fixed formatting errors
c0ffef2  Added runtime guards and resolved CI failures
c4ebd5c  Resolved MacOS and Web CI failures
ce67e4f  resolved latest CI failures
0b189cc  Resolved latest comments
421188a  Resolved Windows CI failures
48f8697  Fix FP16 GELU SVE linkage and resolve function pointer type mismatch
0e0af15  Resolved latest comments
b7e9d92  removed old header file
6ec0002  removed unnecessary compile flags
06b8be6  Resolved new comments/CI
ab04489  Resolved merge conflicts
482cdf0  Resolved merge conflict
7ea64c6  Incorporated copilot comments and testing compile guard
5b241f8  Addressed the latest review comments
adf55a8  Removed unnecessary headers
5d1cef0  Incorporated existing memory allocation method
9a46921  Removed unnecessary header
dca3c38  Update onnxruntime/core/mlas/lib/erf.cpp
erf_neon_fp16.cpp (new file, 156 lines):
/*++

Copyright 2025 FUJITSU LIMITED
Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

    erf_neon_fp16.cpp

Abstract:

    This module contains the ERF NEON FP16 kernel implementation.

--*/

#include "erf_neon_fp16.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)

using _mlas_fp16_ = uint16_t;

// Helpers to safely convert between float and the FP16 bit representation
static float
fp16_to_float(uint16_t h)
{
    __fp16 tmp;
    std::memcpy(&tmp, &h, sizeof(h));
    return (float)tmp;
}

static uint16_t
float_to_fp16(float f)
{
    __fp16 tmp = (__fp16)f;
    uint16_t h;
    std::memcpy(&h, &tmp, sizeof(h));
    return h;
}

// Rational approximation of exp(-x) for x in [0, 6], evaluated entirely in
// FP16. The reciprocal of the denominator is computed with a NEON estimate
// refined by two Newton-Raphson steps.
static inline MLAS_FLOAT16X8
exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
{
    const float16_t a0 = 6.0f;
    MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0);
    x = MlasMinimumFloat16(x, max_x);

    const float16_t c0 = 1.330f;
    const float16_t c1 = -0.390f;
    const float16_t c2 = 0.0288f;

    const float16_t d0 = 1.338f;
    const float16_t d1 = 0.848f;
    const float16_t d2 = 0.467f;

    MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0);
    MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1);
    MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2);

    MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0);
    MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1);
    MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2);
    MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
    MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v);
    num = MlasMultiplyAddFloat16(c2v, x2, num);
    MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
    den = MlasMultiplyAddFloat16(d2v, x2, den);
    MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
    recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
    recip = MlasMultiplyFloat16(recip, MlasReciprocalStepFloat16(den, recip));
    MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
    return result;
}

void
MlasNeonErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N)
{
    const auto* input = reinterpret_cast<const _mlas_fp16_*>(Input);
    auto* output = reinterpret_cast<_mlas_fp16_*>(Output);

    // Coefficients of the erf approximation, rounded for FP16:
    // erf(x) ~= sign(x) * (1 - poly(t) * exp(-x^2)), t = 1 / (1 + p * |x|).
    const float16_t p = 0.328f;
    const float16_t a1 = 0.2505f;
    const float16_t a2 = -0.2881f;
    const float16_t a3 = 1.4102f;
    const float16_t a4 = -1.423f;
    const float16_t a5 = 1.0547f;

    MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
    MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
    MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
    MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
    MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
    MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);

    constexpr float16_t one_fp16 = 1.0f;
    constexpr float16_t neg_one_fp16 = -1.0f;
    constexpr float16_t zero_fp16 = 0.0f;
    constexpr float16_t four_fp16 = 4.0f;

    MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
    MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
    MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
    MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);

    // Vectorized main loop: 8 half-precision lanes per iteration.
    size_t i = 0;
    for (; i + 8 <= N; i += 8) {
        MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&input[i]);
        MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
        MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
        MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
        MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
        MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
        MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
        // t = 1 / denom via reciprocal estimate plus two Newton-Raphson steps.
        MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
        t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
        t = MlasMultiplyFloat16(t, MlasReciprocalStepFloat16(denom, t));
        MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
        MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
        MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
        MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
        MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
        poly = MlasMultiplyAddFloat16(va2, t2, poly);
        poly = MlasMultiplyAddFloat16(va3, t3, poly);
        poly = MlasMultiplyAddFloat16(va4, t4, poly);
        poly = MlasMultiplyAddFloat16(va5, t5, poly);
        MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
        MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
        MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
        MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
        MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
        erf_approx = MlasMinimumFloat16(erf_approx, vone);
        erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
        // For |x| >= 4 erf has saturated, so select +/-1 directly.
        MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
        MlasStoreFloat16x8(&output[i], result);
    }

    // Scalar tail for the remaining (N % 8) elements, computed in float.
    for (; i < N; i++) {
        float x = fp16_to_float(input[i]);
        float sign = (x < 0) ? -1.0f : 1.0f;
        float absx = fabsf(x);

        if (absx > 4.0f) {
            output[i] = float_to_fp16(sign);
            continue;
        }

        float t = 1.0f / (1.0f + p * absx);
        float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
        float exp_neg_x2 = expf(-absx * absx);
        float erf_approx = sign * (1.0f - poly * exp_neg_x2);
        if (erf_approx > 1.0f) erf_approx = 1.0f;
        if (erf_approx < -1.0f) erf_approx = -1.0f;

        output[i] = float_to_fp16(erf_approx);
    }
}
#endif
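Both the vector loop and the scalar tail evaluate the same closed-form approximation, erf(x) ~= sign(x) * (1 - poly(t) * exp(-x^2)) with t = 1 / (1 + p * |x|), which matches the shape of the classical Abramowitz-Stegun formula 7.1.26 with coefficients rounded to FP16-friendly values. As a sanity check, the following standalone sketch mirrors the scalar path and compares it against std::erf; the harness, the Horner evaluation, and the error estimate in the comment are illustrative additions, not part of the PR:

// Standalone sketch: scalar form of the kernel's erf approximation,
// compared against std::erf. Coefficients are copied from the PR; the
// Horner form is algebraically the same polynomial as the expanded one.
#include <cmath>
#include <cstdio>

static float erf_approx(float x) {
    const float p = 0.328f;
    const float a1 = 0.2505f, a2 = -0.2881f, a3 = 1.4102f;
    const float a4 = -1.423f, a5 = 1.0547f;
    float sign = (x < 0.0f) ? -1.0f : 1.0f;
    float absx = std::fabs(x);
    if (absx > 4.0f) return sign;        // erf has saturated to +/-1 here
    float t = 1.0f / (1.0f + p * absx);  // same substitution as the kernel
    float poly = t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))));
    float r = sign * (1.0f - poly * std::exp(-absx * absx));
    if (r > 1.0f) r = 1.0f;
    if (r < -1.0f) r = -1.0f;
    return r;
}

int main() {
    float max_err = 0.0f;
    for (float x = -5.0f; x <= 5.0f; x += 0.01f) {
        max_err = std::fmax(max_err, std::fabs(erf_approx(x) - std::erf(x)));
    }
    // With these rounded coefficients the worst-case error is around a few
    // 1e-3 (largest near x = 0), versus ~1.5e-7 for the exact A&S
    // coefficients; the "good enough for FP16" judgment is an assumption.
    std::printf("max abs error vs std::erf: %g\n", max_err);
    return 0;
}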
erf_neon_fp16.h (new file, 27 lines):
/*++

Copyright 2025 FUJITSU LIMITED
Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

    erf_neon_fp16.h

Abstract:

    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#pragma once

#include <arm_neon.h>

#include "mlasi.h"
#include "fp16_common.h"
#include "softmax_kernel_neon.h"
#include <cstring>

void MlasNeonErfFP16Kernel(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N);
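The header only exposes the erf kernel; its connection to the PR title is the exact Gelu definition, Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), which is what lets an FP16 erf kernel accelerate Gelu. Below is a hedged plain-C++ sketch of that composition, with std::erf standing in for the vector kernel; gelu_via_erf and the small harness are illustrative assumptions, not MLAS code:

// Illustrative only: Gelu expressed through erf, the identity that lets an
// FP16 erf kernel such as MlasNeonErfFP16Kernel accelerate Gelu. std::erf
// stands in for the vector kernel; gelu_via_erf is a hypothetical name.
#include <cmath>
#include <cstddef>
#include <cstdio>

static void gelu_via_erf(const float* x, float* y, size_t n) {
    const float inv_sqrt2 = 0.7071067811865475f;  // 1 / sqrt(2)
    for (size_t i = 0; i < n; ++i) {
        // Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        y[i] = 0.5f * x[i] * (1.0f + std::erf(x[i] * inv_sqrt2));
    }
}

int main() {
    const float in[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
    float out[4];
    gelu_via_erf(in, out, 4);
    for (size_t i = 0; i < 4; ++i) {
        std::printf("gelu(%g) = %g\n", in[i], out[i]);
    }
    return 0;
}

In the actual operator the same composition would run on FP16 data, with MlasNeonErfFP16Kernel (or another platform kernel) supplying the erf values; that wiring is not part of the two files shown above.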