From 101c0f5b805abdc2bf1e5074b98e9faf511aa252 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 16 Mar 2026 20:32:35 -0700 Subject: [PATCH 01/37] Add fused Silu and Gelu kernels for AVX512 --- cmake/onnxruntime_mlas.cmake | 14 +- .../onnxruntime_session_options_config_keys.h | 8 + onnxruntime/contrib_ops/cpu/activations.h | 28 +- onnxruntime/core/mlas/inc/mlas.h | 22 ++ onnxruntime/core/mlas/lib/gelu.cpp | 69 ++++ .../lib/intrinsics/avx512/gelu_avx512f.cpp | 374 ++++++++++++++++++ .../lib/intrinsics/avx512/silu_avx512f.cpp | 157 ++++++++ onnxruntime/core/mlas/lib/mlasi.h | 8 + onnxruntime/core/mlas/lib/platform.cpp | 7 +- onnxruntime/core/mlas/lib/silu.cpp | 47 +++ onnxruntime/core/providers/cpu/tensor/gelu.cc | 12 +- onnxruntime/core/providers/cpu/tensor/gelu.h | 5 + .../test/mlas/bench/bench_transcendental.cpp | 171 ++++++++ .../unittest/test_transcendental_avx512.cpp | 253 ++++++++++++ .../cpu/activation/activation_op_test.cc | 30 ++ 15 files changed, 1179 insertions(+), 26 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/gelu.cpp create mode 100644 onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp create mode 100644 onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp create mode 100644 onnxruntime/core/mlas/lib/silu.cpp create mode 100644 onnxruntime/test/mlas/bench/bench_transcendental.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f1b3b091bbc6e..c166a9e98de7c 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -34,6 +34,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.h ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/silu.cpp + ${MLAS_SRC_DIR}/gelu.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp @@ -200,6 +202,14 @@ function(setup_mlas_source_for_windows) ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set(mlas_platform_srcs_avx512 + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + + set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512") + target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/dgemm.cpp ${mlas_platform_srcs_avx} @@ -211,7 +221,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${mlas_platform_srcs_avx512} ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp @@ -760,6 +770,8 @@ endif() ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index f0a99bc11c8b3..8d1f1c35d4cd3 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -386,6 +386,14 @@ static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm"; // - "1": Disable KleidiAI kernels even if available. static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_kleidiai"; +// Use the minimax erf approximation inside MLAS exact Gelu when available. +// Some platforms may not have this approximation available, in which case this option will have no effect. +// Option values: +// - "0": Use the default exact erf-based Gelu path. [DEFAULT] +// - "1": Use the AVX512 minimax erf approximation for exact Gelu. +static const char* const kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation = + "mlas.gelu_erf_use_minimax_approximation"; + // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. // Refer to MatMulNBits op schema for more details. // If not provided, default is 4. diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index f00fad809968f..71e0e8561e110 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -78,22 +78,22 @@ class QuickGelu : public OpKernel { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - if (alpha_ != 1.0f) { - // TODO: Consider vectorizing this scalar multiplication. - // It needs exposing a new API in MLAS to take in a scalar - // that will be used in the elementwise multiplication. - // Estimate the cost-benefit tradeoff before proceeding - // with that optimization. - for (int64_t i = 0; i < count; i++) { - p_output[i] = p_input[i] * alpha_; - } - - MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); - } else { - // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f - MlasComputeLogistic(p_input, p_output, onnxruntime::narrow(count)); + if (alpha_ == 1.0f) { + MlasComputeSilu(p_input, p_output, onnxruntime::narrow(count)); + return; } + // TODO: Consider vectorizing this scalar multiplication. + // It needs exposing a new API in MLAS to take in a scalar + // that will be used in the elementwise multiplication. + // Estimate the cost-benefit tradeoff before proceeding + // with that optimization. + for (int64_t i = 0; i < count; i++) { + p_output[i] = p_input[i] * alpha_; + } + + MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); + MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow(count)); }, 0); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 80817ff87d736..8df19b5cdaaf9 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1105,6 +1105,11 @@ MlasMaximumPool( // Miscellaneous compute routines. 
// +enum MLAS_GELU_ERF_MODE { + MlasGeluErfModeExact = 0, + MlasGeluErfModeMinimaxApproximation = 1, +}; + void MLASCALL MlasComputeErf( @@ -1113,6 +1118,23 @@ MlasComputeErf( size_t N ); +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N, + MLAS_GELU_ERF_MODE Mode + ); + +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ); + template void MLASCALL diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp new file mode 100644 index 0000000000000..c282d2e261080 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -0,0 +1,69 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu.cpp + +Abstract: + + This module implements routines to compute the exact Gelu function. + +--*/ + +#include + +#include "mlasi.h" + + +void +MLASCALL +MlasGeluKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe, as the computation is not elementwise. + // The caller should guarantee Input and Output do not overlap. + // The current CPU EP kernel where we call this from guarantees that. + for (size_t i = 0; i < N; ++i) { + Output[i] = Input[i] * static_cast(M_SQRT1_2); + } + + MlasComputeErf(Output, Output, N); + + for (size_t i = 0; i < N; ++i) { + Output[i] = 0.5f * Input[i] * (Output[i] + 1.0f); + } +} + +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N, + MLAS_GELU_ERF_MODE Mode + ) +{ +#if !defined(MLAS_TARGET_AMD64) + MLAS_UNREFERENCED_PARAMETER(Mode); +#endif + +#if defined(MLAS_TARGET_AMD64) + if (Mode == MlasGeluErfModeMinimaxApproximation && GetMlasPlatform().GeluErfMinimaxKernelRoutine != nullptr) { + GetMlasPlatform().GeluErfMinimaxKernelRoutine(Input, Output, N); + return; + } +#endif + +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().GeluKernelRoutine(Input, Output, N); +#else + MlasGeluKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp new file mode 100644 index 0000000000000..a9195bd0c9985 --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -0,0 +1,374 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu_avx512f.cpp + +Abstract: + + This module implements routines to compute exact Gelu with AVX512F + intrinsics. 
+ + Idea and code credit for the minimax approximation: OneDNN library + +--*/ + +#include +#include + +#include "mlasi.h" + +namespace { + +struct GeluAvx512Constants { + static constexpr float InvSqrt2 = 0.70710678118654752440f; + static constexpr float Half = 0.5f; + static constexpr float One = 1.0f; + + static constexpr float ErfUpperAbsRange = 3.925f; + static constexpr float ErfSplitBoundary = 0.921875f; + static constexpr float ErfSMALL_P0 = -5.99104969e-4f; + static constexpr float ErfSMALL_P1 = 4.99339588e-3f; + static constexpr float ErfSMALL_P2 = -2.67667342e-2f; + static constexpr float ErfSMALL_P3 = 1.12818025e-1f; + static constexpr float ErfSMALL_P4 = -3.76124859e-1f; + static constexpr float ErfSMALL_P5_Minus_One = 1.28379151e-1f; + static constexpr float ErfBIG_P0 = 1.72948930e-5f; + static constexpr float ErfBIG_P1 = -3.83208680e-4f; + static constexpr float ErfBIG_P2 = 3.88393435e-3f; + static constexpr float ErfBIG_P3 = -2.42545605e-2f; + static constexpr float ErfBIG_P4 = 1.06777847e-1f; + static constexpr float ErfBIG_P5 = 6.34846687e-1f; + static constexpr float ErfBIG_P6_Minus_One = 1.28717512e-1f; + static constexpr float ErfOne = 1.0f; + static constexpr float ExpLowerRange = -88.3762626647949f; + static constexpr float ExpLog2Reciprocal = 1.44269504088896341f; + static constexpr float ExpLog2Hi = -6.93145752e-1f; + static constexpr float ExpLog2Lo = -1.42860677e-6f; + static constexpr float ExpP0 = 1.38319808e-3f; + static constexpr float ExpP1 = 8.37550033e-3f; + static constexpr float ExpP2 = 4.16689515e-2f; + static constexpr float ExpP3 = 1.66664466e-1f; + static constexpr float ExpP4 = 4.99999851e-1f; + static constexpr float ExpP5 = 1.0f; + static constexpr float ExpP6 = 1.0f; + static constexpr float ExpC = 1.25829120e+7f; +}; + +alignas(64) static const uint32_t MlasGeluErfMinimaxTable[6][32] = { + { + 0xa6f2cb94, 0x32827792, 0x3381cc0c, 0x34523d4a, + 0x351ac44d, 0x35f36d88, 0x36ee8229, 0x37b8a3bb, + 0x3867a213, 0x3940033b, 0x3a2a5a1d, 0x3ae35863, + 0x3b7828f2, 0x3c08b14b, 0x3c515ed3, 0xbb503236, + 0xbd8d8e5e, 0xbe8abcd9, 0xbf0c19a2, 0xbeccb328, + 0x3e176ced, 0x3f470d99, 0x3f7abb28, 0x3f800000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + }, + { + 0x3f4c422a, 0x3f4c421f, 0x3f4c4207, 0x3f4c41cb, + 0x3f4c413b, 0x3f4c3fad, 0x3f4c3a2f, 0x3f4c2d40, + 0x3f4c146a, 0x3f4bc341, 0x3f4ad08c, 0x3f48f8cf, + 0x3f45fac7, 0x3f404e07, 0x3f3b980f, 0x3f48dff3, + 0x3f78b21b, 0x3fbb0704, 0x40019c32, 0x3fe536d6, + 0x3f81331e, 0x3e6c8684, 0x3c98f936, 0x00000000, + 0x3f800000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + }, + { + 0xb62173f4, 0x3735e4cf, 0x37f2ff89, 0x388c23be, + 0x3917535c, 0x39ab2ab0, 0x3a60fadb, 0x3af9b960, + 0x3b6e5491, 0x3c0a4ec5, 0x3ca5aa8c, 0x3d2138d9, + 0x3d8737d4, 0x3ddfb660, 0x3e0f27ab, 0x3d94004b, + 0xbe0efdeb, 0xbf1d96c3, 0xbf89db58, 0xbf6d9897, + 0xbef69fb8, 0xbdc4f8a8, 0xbbde6422, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + }, + { + 0xbe081a19, 0xbe084570, 0xbe08639b, 0xbe089837, + 0xbe08f409, 0xbe09ab95, 0xbe0b66d0, 0xbe0e400a, + 0xbe124df8, 0xbe1bde02, 0xbe2f19c9, 0xbe4931bf, + 0xbe685fbc, 0xbe89c95f, 0xbe96cbca, 0xbe8044aa, + 0xbe0550f2, 0x3dcfd6a1, 0x3e94c826, 0x3e79345f, + 0x3decec91, 0x3ca46568, 0x3aa1e00a, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + }, + { + 0xba3d61db, 0x39f097a3, 0x3a5845dc, 0x3ab1fa35, 
+ 0x3b0cefb8, 0x3b653ab6, 0x3bcae527, 0x3c221712, + 0x3c6c5840, 0x3cc0a703, 0x3d1dcc19, 0x3d63656d, + 0x3d955907, 0x3dbf9910, 0x3dd53f69, 0x3db7dcef, + 0x3d639ebe, 0xba6ede48, 0xbd22be69, 0xbd041cf1, + 0xbc64f5ab, 0xbb097a32, 0xb8ebf380, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + }, + { + 0x3cb7d80c, 0x3c9b6050, 0x3c978d11, 0x3c92e850, + 0x3c8d058b, 0x3c848454, 0x3c6cd623, 0x3c4c824b, + 0x3c2a7935, 0x3be0b390, 0x3b0651ac, 0xbb232f53, + 0xbbd42fa0, 0xbc2c5366, 0xbc492c9e, 0xbc2a7aa6, + 0xbbd55d04, 0xba823a76, 0x3b102aa8, 0x3ae25a7e, + 0x3a31f792, 0x38b84375, 0x3689bb5a, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, + } +}; + +MLAS_FORCEINLINE __m512 +MlasLoadMinimaxTable( + const uint32_t* Table + ) +{ + return _mm512_castsi512_ps(_mm512_load_si512(reinterpret_cast(Table))); +} + +MLAS_FORCEINLINE __m512 +MlasGatherMinimaxCoeff( + int Degree, + __m512i Index + ) +{ + const uint32_t* Base = MlasGeluErfMinimaxTable[Degree]; + const __m512 Lo = MlasLoadMinimaxTable(Base); + const __m512 Hi = MlasLoadMinimaxTable(Base + 16); + return _mm512_permutex2var_ps(Lo, Index, Hi); +} + +MLAS_FORCEINLINE __m512 +MlasGeluErfExpVectorAvx512( + __m512 Value + ) +{ + const __m512 ExpLog2Reciprocal = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Reciprocal); + const __m512 ExpLog2Hi = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Hi); + const __m512 ExpLog2Lo = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Lo); + const __m512 ExpP0 = _mm512_set1_ps(GeluAvx512Constants::ExpP0); + const __m512 ExpP1 = _mm512_set1_ps(GeluAvx512Constants::ExpP1); + const __m512 ExpP2 = _mm512_set1_ps(GeluAvx512Constants::ExpP2); + const __m512 ExpP3 = _mm512_set1_ps(GeluAvx512Constants::ExpP3); + const __m512 ExpP4 = _mm512_set1_ps(GeluAvx512Constants::ExpP4); + const __m512 ExpP5 = _mm512_set1_ps(GeluAvx512Constants::ExpP5); + const __m512 ExpP6 = _mm512_set1_ps(GeluAvx512Constants::ExpP6); + const __m512 ExpC = _mm512_set1_ps(GeluAvx512Constants::ExpC); + + __m512 R = _mm512_fmadd_ps(ExpLog2Reciprocal, Value, ExpC); + R = _mm512_sub_ps(R, ExpC); + + __m512 Fx = _mm512_fmadd_ps(R, ExpLog2Hi, Value); + Fx = _mm512_fmadd_ps(R, ExpLog2Lo, Fx); + + __m512 Y = ExpP0; + Y = _mm512_fmadd_ps(Y, Fx, ExpP1); + Y = _mm512_fmadd_ps(Y, Fx, ExpP2); + Y = _mm512_fmadd_ps(Y, Fx, ExpP3); + Y = _mm512_fmadd_ps(Y, Fx, ExpP4); + Y = _mm512_fmadd_ps(Y, Fx, ExpP5); + Y = _mm512_fmadd_ps(Y, Fx, ExpP6); + Y = _mm512_scalef_ps(Y, R); + + return Y; +} + +MLAS_FORCEINLINE __m512 +MlasGeluErfAvx512( + __m512 Value + ) +{ + const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); + const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); + const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0); + const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1); + const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2); + const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3); + const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4); + const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One); + const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0); + const __m512 ErfBigP1 = 
_mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1); + const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2); + const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3); + const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4); + const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5); + const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One); + const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); + const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); + + const __m512 SignMask = _mm512_and_ps(Value, NegZero); + __m512 AbsValue = _mm512_andnot_ps(NegZero, Value); + AbsValue = _mm512_min_ps(ErfUpperAbsRange, AbsValue); + + const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue); + + __m512 SmallResult = ErfSmallP0; + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP1); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP2); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP3); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP4); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP5MinusOne); + SmallResult = _mm512_fmadd_ps(SmallResult, AbsValue, AbsValue); + + const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, ErfSplitBoundary, _CMP_GT_OQ); + const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Zero, AbsValue); + + __m512 BigResult = ErfBigP0; + BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP1); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP2); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP3); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP4); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP5); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP6MinusOne); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput); + + BigResult = _mm512_xor_ps(BigResult, NegZero); + BigResult = _mm512_max_ps(ExpLowerRange, BigResult); + BigResult = _mm512_sub_ps(ErfOne, MlasGeluErfExpVectorAvx512(BigResult)); + + __m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult); + Result = _mm512_or_ps(Result, SignMask); + return Result; +} + +MLAS_FORCEINLINE __m512 +MlasComputeGeluVectorMinimaxAvx512( + __m512 X + ) +{ + const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); + const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); + const __m512 One = _mm512_set1_ps(1.0f); + const __m512 Half = _mm512_set1_ps(0.5f); + const __m512 NegativeInfinity = _mm512_set1_ps(-std::numeric_limits::infinity()); + const __m512 NegativeZero = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + const __m512 RightBound = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x40b15ceeu))); + + const __m512i IndexBias = _mm512_set1_epi32(static_cast(0xc21fffff)); + const __m512i OneI = _mm512_set1_epi32(1); + const __m512i TwentyThreeI = _mm512_set1_epi32(23); + const __m512i TwentyFourI = _mm512_set1_epi32(24); + + const __m512 XPositive = _mm512_and_ps(X, PositiveMask); + + __m512i Index = _mm512_castps_si512(XPositive); + Index = _mm512_add_epi32(Index, IndexBias); + Index = _mm512_srai_epi32(Index, 21); + Index = _mm512_max_epi32(Index, OneI); + Index = _mm512_min_epi32(Index, TwentyFourI); + + const __mmask16 GreaterThanRightBoundMask = _mm512_cmp_ps_mask(XPositive, RightBound, _CMP_GT_OQ); + Index 
= _mm512_mask_blend_epi32(GreaterThanRightBoundMask, Index, TwentyThreeI); + + __m512 Polynomial = MlasGatherMinimaxCoeff(5, Index); + Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(4, Index)); + Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(3, Index)); + Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(2, Index)); + Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(1, Index)); + Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(0, Index)); + + const __m512 Sign = _mm512_and_ps(X, SignMask); + const __m512 ErfPart = _mm512_xor_ps(Polynomial, Sign); + __m512 Result = _mm512_mul_ps(_mm512_mul_ps(X, _mm512_add_ps(ErfPart, One)), Half); + + const __mmask16 PositiveInfinityMask = _mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); + Result = _mm512_mask_mov_ps(Result, PositiveInfinityMask, PositiveInfinity); + + const __mmask16 NegativeInfinityMask = _mm512_cmp_ps_mask(X, NegativeInfinity, _CMP_EQ_OQ); + Result = _mm512_mask_mov_ps(Result, NegativeInfinityMask, NegativeZero); + + return Result; +} + +MLAS_FORCEINLINE __m512 +MlasComputeGeluVectorExactAvx512( + __m512 X + ) +{ + const __m512 InvSqrt2 = _mm512_set1_ps(GeluAvx512Constants::InvSqrt2); + const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half); + const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One); + const __m512 ErfInput = _mm512_mul_ps(X, InvSqrt2); + const __m512 ErfValue = MlasGeluErfAvx512(ErfInput); + return _mm512_mul_ps(_mm512_mul_ps(Half, X), _mm512_add_ps(ErfValue, One)); +} + +void +MlasGeluKernelAvx512FExactImpl( + const float* Input, + float* Output, + size_t N + ) +{ + while (N >= 16) { + const __m512 X = _mm512_loadu_ps(Input); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X); + + _mm512_storeu_ps(Output, Result); + + Input += 16; + Output += 16; + N -= 16; + } + + while (N > 0) { + const float X = *Input++; + *Output++ = GeluAvx512Constants::Half * X * (std::erff(X * GeluAvx512Constants::InvSqrt2) + GeluAvx512Constants::One); + N -= 1; + } +} + +void +MlasGeluKernelAvx512FMinimaxApproxImpl( + const float* Input, + float* Output, + size_t N + ) +{ + size_t Offset = 0; + + while (Offset + 16 <= N) { + const __m512 X = _mm512_loadu_ps(Input + Offset); + const __m512 Result = MlasComputeGeluVectorMinimaxAvx512(X); + _mm512_storeu_ps(Output + Offset, Result); + Offset += 16; + } + + if (Offset < N) { + const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); + const __m512 Result = MlasComputeGeluVectorMinimaxAvx512(X); + _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); + } +} + +} // namespace + +void +MLASCALL +MlasGeluKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + MlasGeluKernelAvx512FExactImpl(Input, Output, N); +} + +void +MLASCALL +MlasGeluKernelAvx512FMinimaxApprox( + const float* Input, + float* Output, + size_t N + ) +{ + MlasGeluKernelAvx512FMinimaxApproxImpl(Input, Output, N); +} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp new file mode 100644 index 0000000000000..2335c693a12ab --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -0,0 +1,157 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + silu_avx512f.cpp + +Abstract: + + This module implements routines to compute the SiLU function with AVX512F + intrinsics. + +--*/ + +#include + +#include "mlasi.h" + +namespace { + +struct SiluAvx512Constants { + static constexpr float Half = 0.5f; + static constexpr float One = 1.0f; + static constexpr float Two = 2.0f; + static constexpr float Ln2 = 0.693147182f; + static constexpr float Log2EF = 1.44269502f; + static constexpr float ExpLnFltMax = 88.3762589f; + static constexpr float ExpLnFltMin = -87.3365479f; + + static constexpr float P1 = 0.999999701f; + static constexpr float P2 = 0.499991506f; + static constexpr float P3 = 0.166676521f; + static constexpr float P4 = 0.0418978221f; + static constexpr float P5 = 0.00828929059f; +}; + +MLAS_FORCEINLINE __m512 +MlasExpApproxAvx512( + __m512 Value + ) +{ + const __m512 Half = _mm512_set1_ps(SiluAvx512Constants::Half); + const __m512 One = _mm512_set1_ps(SiluAvx512Constants::One); + const __m512 Two = _mm512_set1_ps(SiluAvx512Constants::Two); + const __m512 Ln2 = _mm512_set1_ps(SiluAvx512Constants::Ln2); + const __m512 Log2EF = _mm512_set1_ps(SiluAvx512Constants::Log2EF); + const __m512 ExpLnFltMax = _mm512_set1_ps(SiluAvx512Constants::ExpLnFltMax); + const __m512 ExpLnFltMin = _mm512_set1_ps(SiluAvx512Constants::ExpLnFltMin); + const __m512 P1 = _mm512_set1_ps(SiluAvx512Constants::P1); + const __m512 P2 = _mm512_set1_ps(SiluAvx512Constants::P2); + const __m512 P3 = _mm512_set1_ps(SiluAvx512Constants::P3); + const __m512 P4 = _mm512_set1_ps(SiluAvx512Constants::P4); + const __m512 P5 = _mm512_set1_ps(SiluAvx512Constants::P5); + const __m512i ExponentBias = _mm512_set1_epi32(127); + + const __mmask16 UnderflowMask = _mm512_cmp_ps_mask(Value, ExpLnFltMin, _CMP_LT_OS); + + Value = _mm512_min_ps(Value, ExpLnFltMax); + Value = _mm512_max_ps(Value, ExpLnFltMin); + + __m512 Fx = _mm512_fmadd_ps(Value, Log2EF, Half); + Fx = _mm512_floor_ps(Fx); + + const __m512 R = _mm512_fnmadd_ps(Fx, Ln2, Value); + + const __m512 NMinusOne = _mm512_sub_ps(Fx, One); + __m512i Exponent = _mm512_cvttps_epi32(NMinusOne); + Exponent = _mm512_add_epi32(Exponent, ExponentBias); + Exponent = _mm512_slli_epi32(Exponent, 23); + Exponent = _mm512_mask_mov_epi32(Exponent, UnderflowMask, _mm512_setzero_si512()); + const __m512 Pow2NMinusOne = _mm512_castsi512_ps(Exponent); + + __m512 Y = P5; + Y = _mm512_fmadd_ps(Y, R, P4); + Y = _mm512_fmadd_ps(Y, R, P3); + Y = _mm512_fmadd_ps(Y, R, P2); + Y = _mm512_fmadd_ps(Y, R, P1); + Y = _mm512_fmadd_ps(Y, R, One); + + Y = _mm512_mul_ps(Y, Pow2NMinusOne); + Y = _mm512_mul_ps(Y, Two); + return Y; +} + +MLAS_FORCEINLINE __m512 +MlasLogisticApproxAvx512( + __m512 Value + ) +{ + const __m512 One = _mm512_set1_ps(1.0f); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); + + const __m512 XAbs = _mm512_and_ps(Value, PositiveMask); + const __m512 XNeg = _mm512_or_ps(XAbs, SignMask); + + const __m512 E = MlasExpApproxAvx512(XNeg); + const __m512 Y = _mm512_div_ps(E, _mm512_add_ps(E, One)); + const __m512 OneMinusY = _mm512_sub_ps(One, Y); + const __mmask16 NegativeMask = _mm512_cmp_ps_mask(Value, Zero, _CMP_LT_OQ); + + return _mm512_mask_blend_ps(NegativeMask, OneMinusY, Y); +} + +MLAS_FORCEINLINE __m512 +MlasComputeSiluVectorAvx512( + __m512 X + ) +{ + const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); + const __m512 NegativeInfinity = 
_mm512_set1_ps(-std::numeric_limits::infinity()); + const __m512 NegativeZero = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + + __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X)); + + const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + const __mmask16 PositiveInfinityMask = _mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); + Result = _mm512_mask_mov_ps(Result, PositiveInfinityMask, PositiveInfinity); + + const __mmask16 NegativeInfinityMask = _mm512_cmp_ps_mask(X, NegativeInfinity, _CMP_EQ_OQ); + Result = _mm512_mask_mov_ps(Result, NegativeInfinityMask, NegativeZero); + + return Result; +} + +} // namespace + +void +MLASCALL +MlasSiluKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + size_t Offset = 0; + + while (Offset + 16 <= N) { + const __m512 X = _mm512_loadu_ps(Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X); + _mm512_storeu_ps(Output + Offset, Result); + Offset += 16; + } + + if (Offset < N) { + const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X); + _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 3c0ee29896cd9..ca0d277cce423 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1063,6 +1063,8 @@ extern "C" { #endif MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasTanhKernel; @@ -1093,6 +1095,9 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernelAvx512FMinimaxApprox; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelAvx512F; #endif MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; @@ -1437,6 +1442,9 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif #if defined(MLAS_TARGET_AMD64) + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfMinimaxKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 12dcd61b8840e..3081b8c39abaf 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -283,7 +283,10 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; + this->GeluKernelRoutine = MlasGeluKernel; + this->GeluErfMinimaxKernelRoutine = nullptr; this->LogisticKernelRoutine = MlasLogisticKernel; + this->SiluKernelRoutine = MlasSiluKernel; 
this->TanhKernelRoutine = MlasTanhKernel; this->ErfKernelRoutine = MlasErfKernel; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; @@ -460,7 +463,9 @@ Return Value: // if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { - + this->GeluKernelRoutine = MlasGeluKernelAvx512F; + this->GeluErfMinimaxKernelRoutine = MlasGeluKernelAvx512FMinimaxApprox; + this->SiluKernelRoutine = MlasSiluKernelAvx512F; this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F; diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp new file mode 100644 index 0000000000000..0a19a83b1212a --- /dev/null +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -0,0 +1,47 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + silu.cpp + +Abstract: + + This module implements routines to compute the SiLU function. + +--*/ + +#include "mlasi.h" + +void +MLASCALL +MlasSiluKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe, as the computation is not elementwise. + // The caller should guarantee Input and Output do not overlap. + // The current CPU EP kernel where we call this from guarantees that. + MlasComputeLogistic(Input, Output, N); + MlasEltwiseMul(Input, Output, Output, N); +} + +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().SiluKernelRoutine(Input, Output, N); +#else + MlasSiluKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index d55973eda180f..b5995cf4311ca 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -88,16 +88,8 @@ Status Gelu::Compute(OpKernelContext* context) const { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } + MlasComputeGeluErf(p_input, p_output, narrow(count), + use_gelu_erf_minimax_approximation_ ? MlasGeluErfModeMinimaxApproximation : MlasGeluErfModeExact); }, 0); return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h index 13238028d878a..a1419a8a95f08 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.h +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/session/onnxruntime_session_options_config_keys.h" + namespace onnxruntime { template @@ -8,11 +10,14 @@ class Gelu final : public OpKernel { public: explicit Gelu(const OpKernelInfo& info) : OpKernel(info) { approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + use_gelu_erf_minimax_approximation_ = + info.GetConfigOptions().GetConfigOrDefault(kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation, "0") == "1"; } Status Compute(OpKernelContext* ctx) const override; private: std::string approximation_algorithm_; + bool use_gelu_erf_minimax_approximation_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp new file mode 100644 index 0000000000000..24a34a77a905f --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include + +#include "core/mlas/lib/mlasi.h" +#include "test/mlas/bench/bench_util.h" + +namespace { + +// Compare fused MLAS unary activation paths against unfused baselines for +// SiLU and exact GELU(erf). + +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kInvSqrt2 = 0.7071067811865475244f; +constexpr int64_t kFusedBytesPerElement = 2; +constexpr int64_t kSiluUnfusedBytesPerElement = 5; +constexpr int64_t kGeluUnfusedBytesPerElement = 7; + +std::vector MakeInput(size_t n, float min_value, float max_value) { + auto data = RandomVectorUniform(n, min_value, max_value); + + if (!data.empty()) { + data[0] = 0.0f; + } + if (data.size() > 1) { + data[1] = -0.0f; + } + if (data.size() > 2) { + data[2] = -1.0f; + } + if (data.size() > 3) { + data[3] = 1.0f; + } + + return data; +} + +template +void RunDispatchedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * static_cast(n * sizeof(float) * kFusedBytesPerElement)); +} + +template +void RunUnfusedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + int64_t bytes_per_element) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * static_cast(n * sizeof(float) * bytes_per_element)); +} + +static void UnaryKernelArgs(benchmark::internal::Benchmark* b) { + for (int n : {1, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 511, 512, 1024, 4096, 16384, 65536, 262144}) { + b->Arg(n); + } +} + +void BM_SiluDispatch(benchmark::State& state) { + // Fused MLAS SiLU entry point. 
On supported platforms this may dispatch to a + // specialized implementation that combines the activation into a single + // kernel instead of exposing intermediate results. + RunDispatchedUnaryBenchmark(state, MlasComputeSilu, kSiluMinValue, kSiluMaxValue); +} + +void BM_SiluUnfusedDispatch(benchmark::State& state) { + // Unfused SiLU baseline: compute logistic(x) first and then multiply by x in + // a separate elementwise pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeLogistic(input, output, n); + MlasEltwiseMul(input, output, output, n); + }, + kSiluMinValue, + kSiluMaxValue, + kSiluUnfusedBytesPerElement); +} + +void BM_GeluErfDispatchExact(benchmark::State& state) { + // Fused MLAS GELU(erf) entry point using the exact erf-based formulation. + // On AMD64 this goes through the platform dispatch layer and may select an + // architecture-specific implementation. + RunDispatchedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeGeluErf(input, output, n, MlasGeluErfModeExact); + }, + kGeluMinValue, + kGeluMaxValue); +} + +void BM_GeluErfUnfusedExact(benchmark::State& state) { + // Unfused exact GELU baseline: scale by 1/sqrt(2), run erf, then apply the + // final 0.5 * x * (erf(x / sqrt(2)) + 1) transform in a separate pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + for (size_t i = 0; i < n; ++i) { + output[i] = input[i] * kInvSqrt2; + } + + MlasComputeErf(output, output, n); + + for (size_t i = 0; i < n; ++i) { + output[i] = 0.5f * input[i] * (output[i] + 1.0f); + } + }, + kGeluMinValue, + kGeluMaxValue, + kGeluUnfusedBytesPerElement); +} + +void BM_GeluErfDispatchMinimax(benchmark::State& state) { + if (GetMlasPlatform().GeluErfMinimaxKernelRoutine == nullptr) { + state.SkipWithError("GELU erf minimax kernel is not available on this machine."); + return; + } + + // Fused MLAS GELU(erf) entry point using the minimax erf approximation when + // the platform-specific kernel is available. + RunDispatchedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeGeluErf(input, output, n, MlasGeluErfModeMinimaxApproximation); + }, + kGeluMinValue, + kGeluMaxValue); +} + +} // namespace + +BENCHMARK(BM_SiluDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_SiluUnfusedDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfDispatchExact)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfUnfusedExact)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfDispatchMinimax)->Apply(UnaryKernelArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp new file mode 100644 index 0000000000000..af3db909cf5fd --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" + +#include +#include + +#if defined(MLAS_TARGET_AMD64) + +namespace { + +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; + +constexpr float kGeluAbsoluteTolerance = 2e-6f; +constexpr float kGeluRelativeTolerance = 2e-5f; +constexpr float kSiluAbsoluteTolerance = 3e-5f; +constexpr float kSiluRelativeTolerance = 5e-5f; + +constexpr std::array kShortTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255}; + +constexpr std::array kLongTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, + 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; + +bool IsAvx512Available() { + return GetMlasPlatform().Avx512Supported_; +} + +bool UnaryOutputsMatch(float actual, float expected, float absolute_tolerance, float relative_tolerance, + bool check_signed_zero) { + if (std::isnan(expected)) { + return std::isnan(actual); + } + + if (std::isinf(expected)) { + return std::isinf(actual) && (std::signbit(actual) == std::signbit(expected)); + } + + if (check_signed_zero && actual == 0.0f && expected == 0.0f) { + return std::signbit(actual) == std::signbit(expected); + } + + const float diff = std::fabs(actual - expected); + if (diff <= absolute_tolerance) { + return true; + } + + const float scale = std::max(std::fabs(actual), std::fabs(expected)); + return scale > 0.0f && diff <= scale * relative_tolerance; +} + +const std::vector& GetGeluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 0.0f, + -0.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + }; + + return values; +} + +const std::vector& GetSiluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 0.0f, + -0.0f, + -20.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + 20.0f, + }; + + return values; +} + +float ComputeReferenceSilu(float x) { + if (std::isnan(x)) { + return std::numeric_limits::quiet_NaN(); + } + + if (x == std::numeric_limits::infinity()) { + return x; + } + + if (x == -std::numeric_limits::infinity()) { + return -0.0f; + } + + return x / (1.0f + std::exp(-x)); +} + +void FillInput(float* input, size_t n, float minimum_value, float maximum_value, + const std::vector& special_values, uint32_t seed) { + std::mt19937 generator(seed); + std::uniform_real_distribution distribution(minimum_value, maximum_value); + + for (size_t i = 0; i < n; ++i) { + input[i] = distribution(generator); + } + + const size_t special_count = std::min(n, special_values.size()); + for (size_t i = 0; i < special_count; ++i) { + input[i] = special_values[i]; + } +} + +class MlasComputeGeluAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer generic_output_buffer_; + MatrixGuardBuffer avx512_output_buffer_; + + void ExecuteCommon(const std::vector& sizes, size_t iterations) { + if (!IsAvx512Available()) { + GTEST_SKIP() << "AVX512 is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + 
float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kGeluMinValue, kGeluMaxValue, GetGeluSpecialValues(), + static_cast(size * 131u + iteration * 977u + 17u)); + + MlasGeluKernel(input, generic_output, size); + MlasGeluKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Gelu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Gelu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +class MlasComputeSiluAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer avx512_output_buffer_; + + void ExecuteCommon(const std::vector& sizes, size_t iterations) { + if (!IsAvx512Available()) { + GTEST_SKIP() << "AVX512 is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kSiluMinValue, kSiluMaxValue, GetSiluSpecialValues(), + static_cast(size * 149u + iteration * 991u + 31u)); + + MlasSiluKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + const float expected = ComputeReferenceSilu(input[i]); + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", expected=" << expected + << ", abs_diff=" << std::fabs(avx512_output[i] - expected); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Silu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +} // namespace + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } else { + count += MlasLongExecuteTests::RegisterLongExecute(); + count += MlasLongExecuteTests::RegisterLongExecute(); + } + return count; +}); + +#else + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool) { + return size_t{0}; +}); + +#endif diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index d711e050fb913..3a8f2bfecbc29 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -754,5 +754,35 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { } #endif +TEST_F(ActivationOpTest, ONNX_Gelu_MlasErfMinimaxApproximation) { + if 
(GetMlasPlatform().GeluErfMinimaxKernelRoutine == nullptr) { + GTEST_SKIP() << "MLAS GELU erf minimax kernel is not available on this machine."; + } + + SessionOptions session_options; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation, "1")); + + for (const std::vector& input_vals : input_values) { + OpTester test("Gelu", 20, kOnnxDomain); + test.AddAttribute("approximate", "none"); + + std::vector dims{static_cast(input_vals.size())}; + std::vector expected_vals; + expected_vals.reserve(input_vals.size()); + for (const float x : input_vals) { + expected_vals.push_back(static_cast(0.5 * x * (1 + erf(x * M_SQRT1_2)))); + } + + test.AddInput("X", dims, input_vals); + test.AddOutput("Y", dims, expected_vals, false, 2e-5f, 2e-6f); + + test + .Config(session_options) + .ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); + } +} + } // namespace test } // namespace onnxruntime From 8b3c23a5d155509b7154a4c8c02c72c7a6c1b3a5 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 16 Mar 2026 20:40:32 -0700 Subject: [PATCH 02/37] Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index af3db909cf5fd..df70af987f3da 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -204,8 +204,8 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { for (size_t i = 0; i < size; ++i) { const float expected = ComputeReferenceSilu(input[i]); - ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, - kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) << "Silu mismatch at index " << i << " of " << size << ", input=" << input[i] << ", avx512=" << avx512_output[i] From 26ed02581e7427cfaf7e22003348fe3ee4a2510d Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 16 Mar 2026 20:40:40 -0700 Subject: [PATCH 03/37] Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index df70af987f3da..5b7bb158ec0d1 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -25,8 +25,8 @@ constexpr std::array kShortTestSizes = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255}; constexpr std::array kLongTestSizes = { - 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, - 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, + 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; bool IsAvx512Available() { return GetMlasPlatform().Avx512Supported_; From 3bf1f0169e2fa818540d65fd624e58aec8a536b6 Mon Sep 17 00:00:00 2001 
From: Hariharan Seshadri Date: Mon, 16 Mar 2026 20:40:46 -0700 Subject: [PATCH 04/37] Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 5b7bb158ec0d1..f48891ad597a2 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -156,8 +156,8 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { MlasGeluKernelAvx512F(input, avx512_output, size); for (size_t i = 0; i < size; ++i) { - ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], - kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) << "Gelu mismatch at index " << i << " of " << size << ", input=" << input[i] << ", avx512=" << avx512_output[i] From e31395d461b972930a4a1a209afc33eb1de23e0c Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 16 Mar 2026 20:52:37 -0700 Subject: [PATCH 05/37] Slight adjustments in the code --- onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp | 4 ++-- onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index a9195bd0c9985..8f40e4cad597e 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -233,7 +233,7 @@ MlasGeluErfAvx512( BigResult = _mm512_sub_ps(ErfOne, MlasGeluErfExpVectorAvx512(BigResult)); __m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult); - Result = _mm512_or_ps(Result, SignMask); + Result = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(Result), _mm512_castps_si512(SignMask))); return Result; } @@ -371,4 +371,4 @@ MlasGeluKernelAvx512FMinimaxApprox( ) { MlasGeluKernelAvx512FMinimaxApproxImpl(Input, Output, N); -} \ No newline at end of file +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 2335c693a12ab..f4f828d765b7c 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -96,7 +96,7 @@ MlasLogisticApproxAvx512( const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); const __m512 XAbs = _mm512_and_ps(Value, PositiveMask); - const __m512 XNeg = _mm512_or_ps(XAbs, SignMask); + const __m512 XNeg = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(XAbs), _mm512_castps_si512(SignMask))); const __m512 E = MlasExpApproxAvx512(XNeg); const __m512 Y = _mm512_div_ps(E, _mm512_add_ps(E, One)); From 9a793a3c1c4c302cef06e1ee56b1cef239dd6863 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 16 Mar 2026 21:18:49 -0700 Subject: [PATCH 06/37] More build changes --- .../core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp | 12 ++++++------ .../core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp | 2 +- .../providers/cpu/activation/activation_op_test.cc | 4 ---- 3 files 
changed, 7 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index 8f40e4cad597e..3dd8286bc3435 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -202,8 +202,8 @@ MlasGeluErfAvx512( const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); - const __m512 SignMask = _mm512_and_ps(Value, NegZero); - __m512 AbsValue = _mm512_andnot_ps(NegZero, Value); + const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(NegZero))); + __m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(NegZero), _mm512_castps_si512(Value))); AbsValue = _mm512_min_ps(ErfUpperAbsRange, AbsValue); const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue); @@ -228,7 +228,7 @@ MlasGeluErfAvx512( BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP6MinusOne); BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput); - BigResult = _mm512_xor_ps(BigResult, NegZero); + BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(NegZero))); BigResult = _mm512_max_ps(ExpLowerRange, BigResult); BigResult = _mm512_sub_ps(ErfOne, MlasGeluErfExpVectorAvx512(BigResult)); @@ -256,7 +256,7 @@ MlasComputeGeluVectorMinimaxAvx512( const __m512i TwentyThreeI = _mm512_set1_epi32(23); const __m512i TwentyFourI = _mm512_set1_epi32(24); - const __m512 XPositive = _mm512_and_ps(X, PositiveMask); + const __m512 XPositive = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(X), _mm512_castps_si512(PositiveMask))); __m512i Index = _mm512_castps_si512(XPositive); Index = _mm512_add_epi32(Index, IndexBias); @@ -274,8 +274,8 @@ MlasComputeGeluVectorMinimaxAvx512( Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(1, Index)); Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(0, Index)); - const __m512 Sign = _mm512_and_ps(X, SignMask); - const __m512 ErfPart = _mm512_xor_ps(Polynomial, Sign); + const __m512 Sign = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(X), _mm512_castps_si512(SignMask))); + const __m512 ErfPart = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(Polynomial), _mm512_castps_si512(Sign))); __m512 Result = _mm512_mul_ps(_mm512_mul_ps(X, _mm512_add_ps(ErfPart, One)), Half); const __mmask16 PositiveInfinityMask = _mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index f4f828d765b7c..523139b8ffb9a 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -95,7 +95,7 @@ MlasLogisticApproxAvx512( const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); - const __m512 XAbs = _mm512_and_ps(Value, PositiveMask); + const __m512 XAbs = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(PositiveMask))); const __m512 XNeg = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(XAbs), _mm512_castps_si512(SignMask))); const __m512 E = MlasExpApproxAvx512(XNeg); diff --git 
a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 3a8f2bfecbc29..812c2e3bdb516 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -755,10 +755,6 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { #endif TEST_F(ActivationOpTest, ONNX_Gelu_MlasErfMinimaxApproximation) { - if (GetMlasPlatform().GeluErfMinimaxKernelRoutine == nullptr) { - GTEST_SKIP() << "MLAS GELU erf minimax kernel is not available on this machine."; - } - SessionOptions session_options; ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry( kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation, "1")); From 2cd17b1736b495cbf5ed2dc4f688f7fed256eaed Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 16 Mar 2026 21:33:38 -0700 Subject: [PATCH 07/37] More changes --- onnxruntime/test/providers/cpu/activation/activation_op_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 812c2e3bdb516..de3dee1baf273 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -3,6 +3,7 @@ #include "activation_op_test.h" #include "core/providers/cpu/activation/activations.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" From d99bfd81058bb682ee9249d71959850a89184727 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 17 Mar 2026 09:58:14 -0700 Subject: [PATCH 08/37] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/inc/mlas.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 8df19b5cdaaf9..a2a6a3e50e6d2 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1118,6 +1118,10 @@ MlasComputeErf( size_t N ); +// +// Note: The Input and Output buffers for MlasComputeGeluErf must not overlap. +// In-place operation (e.g., passing the same buffer for both parameters) is unsupported. +// void MLASCALL MlasComputeGeluErf( @@ -1127,6 +1131,10 @@ MlasComputeGeluErf( MLAS_GELU_ERF_MODE Mode ); +// +// Note: The Input and Output buffers for MlasComputeSilu must not overlap. +// In-place operation (e.g., passing the same buffer for both parameters) is unsupported. +// void MLASCALL MlasComputeSilu( From 275b69a8688fb0b8f91eff39b8898d2fcf1f23c0 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 17 Mar 2026 09:58:59 -0700 Subject: [PATCH 09/37] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/providers/cpu/tensor/gelu.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h index a1419a8a95f08..0f9e2f41d568f 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.h +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#pragma once + #include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { From 3a80418deba92625ef427170a00463c23d46f209 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 10:00:47 -0700 Subject: [PATCH 10/37] Fix ARM build + Copilot suggestions --- onnxruntime/test/mlas/bench/bench_transcendental.cpp | 10 +++------- .../test/mlas/unittest/test_transcendental_avx512.cpp | 10 +++++----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index 24a34a77a905f..9dfa4020ffd51 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -146,13 +146,9 @@ void BM_GeluErfUnfusedExact(benchmark::State& state) { } void BM_GeluErfDispatchMinimax(benchmark::State& state) { - if (GetMlasPlatform().GeluErfMinimaxKernelRoutine == nullptr) { - state.SkipWithError("GELU erf minimax kernel is not available on this machine."); - return; - } - - // Fused MLAS GELU(erf) entry point using the minimax erf approximation when - // the platform-specific kernel is available. + // Fused MLAS GELU(erf) entry point requesting the minimax erf mode. MLAS + // falls back to the exact GELU path when a platform-specific minimax kernel + // is not available. RunDispatchedUnaryBenchmark( state, [](const float* input, float* output, size_t n) { diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index f48891ad597a2..32e1c4470fb49 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -206,11 +206,11 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { const float expected = ComputeReferenceSilu(input[i]); ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) - << "Silu mismatch at index " << i << " of " << size - << ", input=" << input[i] - << ", avx512=" << avx512_output[i] - << ", expected=" << expected - << ", abs_diff=" << std::fabs(avx512_output[i] - expected); + << "Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", expected=" << expected + << ", abs_diff=" << std::fabs(avx512_output[i] - expected); } } } From a3f7033974d1d916244159eca37c7b5f7fac9268 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 17 Mar 2026 10:28:33 -0700 Subject: [PATCH 11/37] Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../test/mlas/unittest/test_transcendental_avx512.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 32e1c4470fb49..f48891ad597a2 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -206,11 +206,11 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { const float expected = ComputeReferenceSilu(input[i]); ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) - << "Silu mismatch at index " << i << " of " << size - << ", input=" << input[i] - << ", 
avx512=" << avx512_output[i] - << ", expected=" << expected - << ", abs_diff=" << std::fabs(avx512_output[i] - expected); + << "Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", expected=" << expected + << ", abs_diff=" << std::fabs(avx512_output[i] - expected); } } } From 08157abd82841431efbc9c03375058c346202bbb Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 19:06:06 -0700 Subject: [PATCH 12/37] Remove Minimax approx + address PR feedback --- .../onnxruntime_session_options_config_keys.h | 8 - onnxruntime/core/mlas/inc/mlas.h | 8 +- onnxruntime/core/mlas/lib/gelu.cpp | 14 +- .../lib/intrinsics/avx512/gelu_avx512f.cpp | 171 ------------------ onnxruntime/core/mlas/lib/mlasi.h | 2 - onnxruntime/core/mlas/lib/platform.cpp | 2 - onnxruntime/core/providers/cpu/tensor/gelu.cc | 3 +- onnxruntime/core/providers/cpu/tensor/gelu.h | 5 - .../test/mlas/bench/bench_transcendental.cpp | 16 +- .../unittest/test_transcendental_avx512.cpp | 16 +- .../cpu/activation/activation_op_test.cc | 26 --- 11 files changed, 14 insertions(+), 257 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 8d1f1c35d4cd3..f0a99bc11c8b3 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -386,14 +386,6 @@ static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm"; // - "1": Disable KleidiAI kernels even if available. static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_kleidiai"; -// Use the minimax erf approximation inside MLAS exact Gelu when available. -// Some platforms may not have this approximation available, in which case this option will have no effect. -// Option values: -// - "0": Use the default exact erf-based Gelu path. [DEFAULT] -// - "1": Use the AVX512 minimax erf approximation for exact Gelu. -static const char* const kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation = - "mlas.gelu_erf_use_minimax_approximation"; - // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. // Refer to MatMulNBits op schema for more details. // If not provided, default is 4. diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index a2a6a3e50e6d2..0b9af72b8a4cd 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1105,11 +1105,6 @@ MlasMaximumPool( // Miscellaneous compute routines. 
// -enum MLAS_GELU_ERF_MODE { - MlasGeluErfModeExact = 0, - MlasGeluErfModeMinimaxApproximation = 1, -}; - void MLASCALL MlasComputeErf( @@ -1127,8 +1122,7 @@ MLASCALL MlasComputeGeluErf( const float* Input, float* Output, - size_t N, - MLAS_GELU_ERF_MODE Mode + size_t N ); // diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index c282d2e261080..76948656bba77 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -46,21 +46,9 @@ MLASCALL MlasComputeGeluErf( const float* Input, float* Output, - size_t N, - MLAS_GELU_ERF_MODE Mode + size_t N ) { -#if !defined(MLAS_TARGET_AMD64) - MLAS_UNREFERENCED_PARAMETER(Mode); -#endif - -#if defined(MLAS_TARGET_AMD64) - if (Mode == MlasGeluErfModeMinimaxApproximation && GetMlasPlatform().GeluErfMinimaxKernelRoutine != nullptr) { - GetMlasPlatform().GeluErfMinimaxKernelRoutine(Input, Output, N); - return; - } -#endif - #if defined(MLAS_TARGET_AMD64) GetMlasPlatform().GeluKernelRoutine(Input, Output, N); #else diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index 3dd8286bc3435..b7665512ba512 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -13,12 +13,9 @@ Module Name: This module implements routines to compute exact Gelu with AVX512F intrinsics. - Idea and code credit for the minimax approximation: OneDNN library - --*/ #include -#include #include "mlasi.h" @@ -59,89 +56,6 @@ struct GeluAvx512Constants { static constexpr float ExpC = 1.25829120e+7f; }; -alignas(64) static const uint32_t MlasGeluErfMinimaxTable[6][32] = { - { - 0xa6f2cb94, 0x32827792, 0x3381cc0c, 0x34523d4a, - 0x351ac44d, 0x35f36d88, 0x36ee8229, 0x37b8a3bb, - 0x3867a213, 0x3940033b, 0x3a2a5a1d, 0x3ae35863, - 0x3b7828f2, 0x3c08b14b, 0x3c515ed3, 0xbb503236, - 0xbd8d8e5e, 0xbe8abcd9, 0xbf0c19a2, 0xbeccb328, - 0x3e176ced, 0x3f470d99, 0x3f7abb28, 0x3f800000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - }, - { - 0x3f4c422a, 0x3f4c421f, 0x3f4c4207, 0x3f4c41cb, - 0x3f4c413b, 0x3f4c3fad, 0x3f4c3a2f, 0x3f4c2d40, - 0x3f4c146a, 0x3f4bc341, 0x3f4ad08c, 0x3f48f8cf, - 0x3f45fac7, 0x3f404e07, 0x3f3b980f, 0x3f48dff3, - 0x3f78b21b, 0x3fbb0704, 0x40019c32, 0x3fe536d6, - 0x3f81331e, 0x3e6c8684, 0x3c98f936, 0x00000000, - 0x3f800000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - }, - { - 0xb62173f4, 0x3735e4cf, 0x37f2ff89, 0x388c23be, - 0x3917535c, 0x39ab2ab0, 0x3a60fadb, 0x3af9b960, - 0x3b6e5491, 0x3c0a4ec5, 0x3ca5aa8c, 0x3d2138d9, - 0x3d8737d4, 0x3ddfb660, 0x3e0f27ab, 0x3d94004b, - 0xbe0efdeb, 0xbf1d96c3, 0xbf89db58, 0xbf6d9897, - 0xbef69fb8, 0xbdc4f8a8, 0xbbde6422, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - }, - { - 0xbe081a19, 0xbe084570, 0xbe08639b, 0xbe089837, - 0xbe08f409, 0xbe09ab95, 0xbe0b66d0, 0xbe0e400a, - 0xbe124df8, 0xbe1bde02, 0xbe2f19c9, 0xbe4931bf, - 0xbe685fbc, 0xbe89c95f, 0xbe96cbca, 0xbe8044aa, - 0xbe0550f2, 0x3dcfd6a1, 0x3e94c826, 0x3e79345f, - 0x3decec91, 0x3ca46568, 0x3aa1e00a, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - }, - { - 0xba3d61db, 0x39f097a3, 0x3a5845dc, 0x3ab1fa35, - 0x3b0cefb8, 0x3b653ab6, 0x3bcae527, 0x3c221712, - 0x3c6c5840, 0x3cc0a703, 0x3d1dcc19, 0x3d63656d, - 0x3d955907, 0x3dbf9910, 
0x3dd53f69, 0x3db7dcef, - 0x3d639ebe, 0xba6ede48, 0xbd22be69, 0xbd041cf1, - 0xbc64f5ab, 0xbb097a32, 0xb8ebf380, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - }, - { - 0x3cb7d80c, 0x3c9b6050, 0x3c978d11, 0x3c92e850, - 0x3c8d058b, 0x3c848454, 0x3c6cd623, 0x3c4c824b, - 0x3c2a7935, 0x3be0b390, 0x3b0651ac, 0xbb232f53, - 0xbbd42fa0, 0xbc2c5366, 0xbc492c9e, 0xbc2a7aa6, - 0xbbd55d04, 0xba823a76, 0x3b102aa8, 0x3ae25a7e, - 0x3a31f792, 0x38b84375, 0x3689bb5a, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, - } -}; - -MLAS_FORCEINLINE __m512 -MlasLoadMinimaxTable( - const uint32_t* Table - ) -{ - return _mm512_castsi512_ps(_mm512_load_si512(reinterpret_cast(Table))); -} - -MLAS_FORCEINLINE __m512 -MlasGatherMinimaxCoeff( - int Degree, - __m512i Index - ) -{ - const uint32_t* Base = MlasGeluErfMinimaxTable[Degree]; - const __m512 Lo = MlasLoadMinimaxTable(Base); - const __m512 Hi = MlasLoadMinimaxTable(Base + 16); - return _mm512_permutex2var_ps(Lo, Index, Hi); -} - MLAS_FORCEINLINE __m512 MlasGeluErfExpVectorAvx512( __m512 Value @@ -237,56 +151,6 @@ MlasGeluErfAvx512( return Result; } -MLAS_FORCEINLINE __m512 -MlasComputeGeluVectorMinimaxAvx512( - __m512 X - ) -{ - const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); - const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); - const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); - const __m512 One = _mm512_set1_ps(1.0f); - const __m512 Half = _mm512_set1_ps(0.5f); - const __m512 NegativeInfinity = _mm512_set1_ps(-std::numeric_limits::infinity()); - const __m512 NegativeZero = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); - const __m512 RightBound = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x40b15ceeu))); - - const __m512i IndexBias = _mm512_set1_epi32(static_cast(0xc21fffff)); - const __m512i OneI = _mm512_set1_epi32(1); - const __m512i TwentyThreeI = _mm512_set1_epi32(23); - const __m512i TwentyFourI = _mm512_set1_epi32(24); - - const __m512 XPositive = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(X), _mm512_castps_si512(PositiveMask))); - - __m512i Index = _mm512_castps_si512(XPositive); - Index = _mm512_add_epi32(Index, IndexBias); - Index = _mm512_srai_epi32(Index, 21); - Index = _mm512_max_epi32(Index, OneI); - Index = _mm512_min_epi32(Index, TwentyFourI); - - const __mmask16 GreaterThanRightBoundMask = _mm512_cmp_ps_mask(XPositive, RightBound, _CMP_GT_OQ); - Index = _mm512_mask_blend_epi32(GreaterThanRightBoundMask, Index, TwentyThreeI); - - __m512 Polynomial = MlasGatherMinimaxCoeff(5, Index); - Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(4, Index)); - Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(3, Index)); - Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(2, Index)); - Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(1, Index)); - Polynomial = _mm512_fmadd_ps(Polynomial, XPositive, MlasGatherMinimaxCoeff(0, Index)); - - const __m512 Sign = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(X), _mm512_castps_si512(SignMask))); - const __m512 ErfPart = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(Polynomial), _mm512_castps_si512(Sign))); - __m512 Result = _mm512_mul_ps(_mm512_mul_ps(X, _mm512_add_ps(ErfPart, One)), Half); - - const __mmask16 PositiveInfinityMask = 
_mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); - Result = _mm512_mask_mov_ps(Result, PositiveInfinityMask, PositiveInfinity); - - const __mmask16 NegativeInfinityMask = _mm512_cmp_ps_mask(X, NegativeInfinity, _CMP_EQ_OQ); - Result = _mm512_mask_mov_ps(Result, NegativeInfinityMask, NegativeZero); - - return Result; -} - MLAS_FORCEINLINE __m512 MlasComputeGeluVectorExactAvx512( __m512 X @@ -325,30 +189,6 @@ MlasGeluKernelAvx512FExactImpl( } } -void -MlasGeluKernelAvx512FMinimaxApproxImpl( - const float* Input, - float* Output, - size_t N - ) -{ - size_t Offset = 0; - - while (Offset + 16 <= N) { - const __m512 X = _mm512_loadu_ps(Input + Offset); - const __m512 Result = MlasComputeGeluVectorMinimaxAvx512(X); - _mm512_storeu_ps(Output + Offset, Result); - Offset += 16; - } - - if (Offset < N) { - const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); - const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); - const __m512 Result = MlasComputeGeluVectorMinimaxAvx512(X); - _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); - } -} - } // namespace void @@ -361,14 +201,3 @@ MlasGeluKernelAvx512F( { MlasGeluKernelAvx512FExactImpl(Input, Output, N); } - -void -MLASCALL -MlasGeluKernelAvx512FMinimaxApprox( - const float* Input, - float* Output, - size_t N - ) -{ - MlasGeluKernelAvx512FMinimaxApproxImpl(Input, Output, N); -} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index ca0d277cce423..b670459792ae3 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1096,7 +1096,6 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernelAvx512F; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernelAvx512FMinimaxApprox; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelAvx512F; #endif @@ -1443,7 +1442,6 @@ struct MLAS_PLATFORM { #endif #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluKernelRoutine; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfMinimaxKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3081b8c39abaf..eee515c6d9960 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -284,7 +284,6 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; this->GeluKernelRoutine = MlasGeluKernel; - this->GeluErfMinimaxKernelRoutine = nullptr; this->LogisticKernelRoutine = MlasLogisticKernel; this->SiluKernelRoutine = MlasSiluKernel; this->TanhKernelRoutine = MlasTanhKernel; @@ -464,7 +463,6 @@ Return Value: if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { this->GeluKernelRoutine = MlasGeluKernelAvx512F; - this->GeluErfMinimaxKernelRoutine = MlasGeluKernelAvx512FMinimaxApprox; this->SiluKernelRoutine = MlasSiluKernelAvx512F; this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index b5995cf4311ca..28bb63561fac6 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc 
@@ -88,8 +88,7 @@ Status Gelu::Compute(OpKernelContext* context) const { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - MlasComputeGeluErf(p_input, p_output, narrow(count), - use_gelu_erf_minimax_approximation_ ? MlasGeluErfModeMinimaxApproximation : MlasGeluErfModeExact); + MlasComputeGeluErf(p_input, p_output, narrow(count)); }, 0); return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h index 0f9e2f41d568f..14a070609a69b 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.h +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -3,8 +3,6 @@ #pragma once -#include "core/session/onnxruntime_session_options_config_keys.h" - namespace onnxruntime { template @@ -12,14 +10,11 @@ class Gelu final : public OpKernel { public: explicit Gelu(const OpKernelInfo& info) : OpKernel(info) { approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); - use_gelu_erf_minimax_approximation_ = - info.GetConfigOptions().GetConfigOrDefault(kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation, "0") == "1"; } Status Compute(OpKernelContext* ctx) const override; private: std::string approximation_algorithm_; - bool use_gelu_erf_minimax_approximation_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index 9dfa4020ffd51..02adf4590aa74 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -118,7 +118,7 @@ void BM_GeluErfDispatchExact(benchmark::State& state) { RunDispatchedUnaryBenchmark( state, [](const float* input, float* output, size_t n) { - MlasComputeGeluErf(input, output, n, MlasGeluErfModeExact); + MlasComputeGeluErf(input, output, n); }, kGeluMinValue, kGeluMaxValue); @@ -145,23 +145,9 @@ void BM_GeluErfUnfusedExact(benchmark::State& state) { kGeluUnfusedBytesPerElement); } -void BM_GeluErfDispatchMinimax(benchmark::State& state) { - // Fused MLAS GELU(erf) entry point requesting the minimax erf mode. MLAS - // falls back to the exact GELU path when a platform-specific minimax kernel - // is not available. 
- RunDispatchedUnaryBenchmark( - state, - [](const float* input, float* output, size_t n) { - MlasComputeGeluErf(input, output, n, MlasGeluErfModeMinimaxApproximation); - }, - kGeluMinValue, - kGeluMaxValue); -} - } // namespace BENCHMARK(BM_SiluDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); BENCHMARK(BM_SiluUnfusedDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); BENCHMARK(BM_GeluErfDispatchExact)->Apply(UnaryKernelArgs)->UseRealTime(); BENCHMARK(BM_GeluErfUnfusedExact)->Apply(UnaryKernelArgs)->UseRealTime(); -BENCHMARK(BM_GeluErfDispatchMinimax)->Apply(UnaryKernelArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index f48891ad597a2..138613d46d21c 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -28,8 +28,12 @@ constexpr std::array kLongTestSizes = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; -bool IsAvx512Available() { - return GetMlasPlatform().Avx512Supported_; +bool IsGeluAvx512Dispatched() { + return GetMlasPlatform().GeluKernelRoutine == MlasGeluKernelAvx512F; +} + +bool IsSiluAvx512Dispatched() { + return GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F; } bool UnaryOutputsMatch(float actual, float expected, float absolute_tolerance, float relative_tolerance, @@ -139,8 +143,8 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { MatrixGuardBuffer avx512_output_buffer_; void ExecuteCommon(const std::vector& sizes, size_t iterations) { - if (!IsAvx512Available()) { - GTEST_SKIP() << "AVX512 is not available on this machine."; + if (!IsGeluAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F GELU dispatch is not available on this machine."; } for (size_t size : sizes) { @@ -188,8 +192,8 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { MatrixGuardBuffer avx512_output_buffer_; void ExecuteCommon(const std::vector& sizes, size_t iterations) { - if (!IsAvx512Available()) { - GTEST_SKIP() << "AVX512 is not available on this machine."; + if (!IsSiluAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F SiLU dispatch is not available on this machine."; } for (size_t size : sizes) { diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index de3dee1baf273..c46a0528079fb 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -755,31 +755,5 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { } #endif -TEST_F(ActivationOpTest, ONNX_Gelu_MlasErfMinimaxApproximation) { - SessionOptions session_options; - ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry( - kOrtSessionOptionsMlasGeluErfUseMinimaxApproximation, "1")); - - for (const std::vector& input_vals : input_values) { - OpTester test("Gelu", 20, kOnnxDomain); - test.AddAttribute("approximate", "none"); - - std::vector dims{static_cast(input_vals.size())}; - std::vector expected_vals; - expected_vals.reserve(input_vals.size()); - for (const float x : input_vals) { - expected_vals.push_back(static_cast(0.5 * x * (1 + erf(x * M_SQRT1_2)))); - } - - test.AddInput("X", dims, input_vals); - test.AddOutput("Y", dims, expected_vals, false, 2e-5f, 2e-6f); - - test - .Config(session_options) - .ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); - } -} - } // 
namespace test } // namespace onnxruntime From 7d2342581a94b725899b02f130604e8cbad4affb Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 19:27:51 -0700 Subject: [PATCH 13/37] Update onnxruntime/test/providers/cpu/activation/activation_op_test.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/test/providers/cpu/activation/activation_op_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index c46a0528079fb..d711e050fb913 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -3,7 +3,6 @@ #include "activation_op_test.h" #include "core/providers/cpu/activation/activations.h" -#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" From 9f9bac7708151cc72e7579dae263f6a9be6b5454 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 19:28:09 -0700 Subject: [PATCH 14/37] Update onnxruntime/test/mlas/bench/bench_transcendental.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/test/mlas/bench/bench_transcendental.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index 02adf4590aa74..b03f4e45c41a0 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -5,7 +5,7 @@ #include -#include "core/mlas/lib/mlasi.h" +#include "mlas.h" #include "test/mlas/bench/bench_util.h" namespace { From dc63f2fceb98a488600dcdd9563fcb64e88416e2 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 19:32:07 -0700 Subject: [PATCH 15/37] Copilot comments --- .../core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index b7665512ba512..4cdd6c45c5600 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -15,8 +15,6 @@ Module Name: --*/ -#include - #include "mlasi.h" namespace { @@ -182,10 +180,12 @@ MlasGeluKernelAvx512FExactImpl( N -= 16; } - while (N > 0) { - const float X = *Input++; - *Output++ = GeluAvx512Constants::Half * X * (std::erff(X * GeluAvx512Constants::InvSqrt2) + GeluAvx512Constants::One); - N -= 1; + if (N > 0) { + const __mmask16 TailMask = __mmask16((1u << static_cast(N)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X); + + _mm512_mask_storeu_ps(Output, TailMask, Result); } } From b8ada7649592d85c87e8604e59ff31d2284c1313 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 19:41:38 -0700 Subject: [PATCH 16/37] Update onnxruntime/core/mlas/lib/silu.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/silu.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp index 0a19a83b1212a..c8422d70374f2 100644 
--- a/onnxruntime/core/mlas/lib/silu.cpp +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -25,8 +25,7 @@ MlasSiluKernel( ) { // This kernel is not buffer alias safe, as the computation is not elementwise. - // The caller should guarantee Input and Output do not overlap. - // The current CPU EP kernel where we call this from guarantees that. + // Callers must ensure that Input and Output do not overlap (see mlas.h for details). MlasComputeLogistic(Input, Output, N); MlasEltwiseMul(Input, Output, Output, N); } From ad7a0c25ae3deadc52cad3d84b44a0f090248648 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 19:42:03 -0700 Subject: [PATCH 17/37] Update onnxruntime/core/mlas/lib/gelu.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/gelu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index 76948656bba77..b8ba8cd554552 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -28,8 +28,8 @@ MlasGeluKernel( ) { // This kernel is not buffer alias safe, as the computation is not elementwise. - // The caller should guarantee Input and Output do not overlap. - // The current CPU EP kernel where we call this from guarantees that. + // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). + // for (size_t i = 0; i < N; ++i) { Output[i] = Input[i] * static_cast(M_SQRT1_2); } From cbffffa94a26763997b417ceb4814c6094fdcce4 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 19:44:47 -0700 Subject: [PATCH 18/37] Adjust comment --- onnxruntime/core/mlas/lib/silu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp index c8422d70374f2..a71d86fc38d5d 100644 --- a/onnxruntime/core/mlas/lib/silu.cpp +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -25,7 +25,7 @@ MlasSiluKernel( ) { // This kernel is not buffer alias safe, as the computation is not elementwise. - // Callers must ensure that Input and Output do not overlap (see mlas.h for details). + // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). 
MlasComputeLogistic(Input, Output, N); MlasEltwiseMul(Input, Output, Output, N); } From 50980aa9598dbe4cc27272d894e1bde5576f130d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 20:01:22 -0700 Subject: [PATCH 19/37] Copilot comments --- onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 523139b8ffb9a..61219326820e4 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -56,7 +56,7 @@ MlasExpApproxAvx512( const __m512 P5 = _mm512_set1_ps(SiluAvx512Constants::P5); const __m512i ExponentBias = _mm512_set1_epi32(127); - const __mmask16 UnderflowMask = _mm512_cmp_ps_mask(Value, ExpLnFltMin, _CMP_LT_OS); + const __mmask16 UnderflowMask = _mm512_cmp_ps_mask(Value, ExpLnFltMin, _CMP_LT_OQ); Value = _mm512_min_ps(Value, ExpLnFltMax); Value = _mm512_max_ps(Value, ExpLnFltMin); From da555836e9072d9b48d33338f0b180ab9f48e1ac Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 20:21:49 -0700 Subject: [PATCH 20/37] Copilot comments --- .../lib/intrinsics/avx512/gelu_avx512f.cpp | 5 +++- .../lib/intrinsics/avx512/silu_avx512f.cpp | 9 ++++-- .../unittest/test_transcendental_avx512.cpp | 30 +++++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index 4cdd6c45c5600..358b882b883be 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -15,11 +15,14 @@ Module Name: --*/ +#include + #include "mlasi.h" namespace { struct GeluAvx512Constants { + static constexpr int32_t SignBitMask = INT32_MIN; static constexpr float InvSqrt2 = 0.70710678118654752440f; static constexpr float Half = 0.5f; static constexpr float One = 1.0f; @@ -94,7 +97,7 @@ MlasGeluErfAvx512( __m512 Value ) { - const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask)); const __m512 Zero = _mm512_setzero_ps(); const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 61219326820e4..2c2e118d26d51 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -15,6 +15,7 @@ Module Name: --*/ +#include #include #include "mlasi.h" @@ -22,6 +23,8 @@ Module Name: namespace { struct SiluAvx512Constants { + static constexpr int32_t SignBitMask = INT32_MIN; + static constexpr int32_t PositiveMask = INT32_MAX; static constexpr float Half = 0.5f; static constexpr float One = 1.0f; static constexpr float Two = 2.0f; @@ -92,8 +95,8 @@ MlasLogisticApproxAvx512( { const __m512 One = _mm512_set1_ps(1.0f); const __m512 Zero = _mm512_setzero_ps(); - const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); - const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); + const __m512 SignMask = 
_mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::SignBitMask)); + const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::PositiveMask)); const __m512 XAbs = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(PositiveMask))); const __m512 XNeg = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(XAbs), _mm512_castps_si512(SignMask))); @@ -113,7 +116,7 @@ MlasComputeSiluVectorAvx512( { const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); const __m512 NegativeInfinity = _mm512_set1_ps(-std::numeric_limits::infinity()); - const __m512 NegativeZero = _mm512_castsi512_ps(_mm512_set1_epi32(int(0x80000000u))); + const __m512 NegativeZero = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::SignBitMask)); __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X)); diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 138613d46d21c..3366819ffeb17 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -140,6 +140,7 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { private: MatrixGuardBuffer input_buffer_; MatrixGuardBuffer generic_output_buffer_; + MatrixGuardBuffer public_output_buffer_; MatrixGuardBuffer avx512_output_buffer_; void ExecuteCommon(const std::vector& sizes, size_t iterations) { @@ -151,15 +152,25 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { for (size_t iteration = 0; iteration < iterations; ++iteration) { float* input = input_buffer_.GetBuffer(size); float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); float* avx512_output = avx512_output_buffer_.GetBuffer(size); FillInput(input, size, kGeluMinValue, kGeluMaxValue, GetGeluSpecialValues(), static_cast(size * 131u + iteration * 977u + 17u)); MlasGeluKernel(input, generic_output, size); + MlasComputeGeluErf(input, public_output, size); MlasGeluKernelAvx512F(input, avx512_output, size); for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public Gelu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) << "Gelu mismatch at index " << i << " of " << size @@ -189,6 +200,7 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { class MlasComputeSiluAvx512Test : public MlasTestBase { private: MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer public_output_buffer_; MatrixGuardBuffer avx512_output_buffer_; void ExecuteCommon(const std::vector& sizes, size_t iterations) { @@ -199,15 +211,25 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { for (size_t size : sizes) { for (size_t iteration = 0; iteration < iterations; ++iteration) { float* input = input_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); float* avx512_output = avx512_output_buffer_.GetBuffer(size); FillInput(input, size, kSiluMinValue, kSiluMaxValue, GetSiluSpecialValues(), static_cast(size * 149u + iteration * 991u + 31u)); + 
MlasComputeSilu(input, public_output, size); MlasSiluKernelAvx512F(input, avx512_output, size); for (size_t i = 0; i < size; ++i) { const float expected = ComputeReferenceSilu(input[i]); + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], expected, + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", expected=" << expected + << ", abs_diff=" << std::fabs(public_output[i] - expected); + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) << "Silu mismatch at index " << i << " of " << size @@ -215,6 +237,14 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { << ", avx512=" << avx512_output[i] << ", expected=" << expected << ", abs_diff=" << std::fabs(avx512_output[i] - expected); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public/API Silu dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); } } } From bb3cd555e362af612d8f6f2b6ecdbf291b242a11 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 20:29:39 -0700 Subject: [PATCH 21/37] Update onnxruntime/test/mlas/bench/bench_transcendental.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/test/mlas/bench/bench_transcendental.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index b03f4e45c41a0..2a02a33214032 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -6,7 +6,7 @@ #include #include "mlas.h" -#include "test/mlas/bench/bench_util.h" +#include "bench_util.h" namespace { From f6a22fcddb0251636b04325db3025ed44e9b7381 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 20:43:16 -0700 Subject: [PATCH 22/37] Adjust AVX512 path to match generic path for special values --- onnxruntime/core/mlas/lib/gelu.cpp | 8 ++++- .../lib/intrinsics/avx512/silu_avx512f.cpp | 10 ------ onnxruntime/core/mlas/lib/silu.cpp | 13 ++++++++ .../unittest/test_transcendental_avx512.cpp | 32 ++++++------------- 4 files changed, 29 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index b8ba8cd554552..6ddada8a327e8 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -18,6 +18,12 @@ Module Name: #include "mlasi.h" +namespace { + +constexpr float kInvSqrt2 = 0.70710678118654752440f; + +} // namespace + void MLASCALL @@ -31,7 +37,7 @@ MlasGeluKernel( // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). 
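// Editor's sketch (not part of the patch series): scalar form of the exact erf-based GELU
// that MlasGeluKernel evaluates in separate scale/erf/finish passes, shown only to make the
// kInvSqrt2 scaling step above easier to follow. GeluErfReference is a hypothetical helper
// used purely for illustration.
#include <cmath>
static float GeluErfReference(float x) {
    // GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))); kInvSqrt2 ~= 0.7071067811865475 is 1/sqrt(2).
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678118654752440f));
}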
// for (size_t i = 0; i < N; ++i) { - Output[i] = Input[i] * static_cast(M_SQRT1_2); + Output[i] = Input[i] * kInvSqrt2; } MlasComputeErf(Output, Output, N); diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 2c2e118d26d51..def943b7c6aae 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -114,21 +114,11 @@ MlasComputeSiluVectorAvx512( __m512 X ) { - const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); - const __m512 NegativeInfinity = _mm512_set1_ps(-std::numeric_limits::infinity()); - const __m512 NegativeZero = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::SignBitMask)); - __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X)); const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); Result = _mm512_mask_mov_ps(Result, NaNMask, X); - const __mmask16 PositiveInfinityMask = _mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); - Result = _mm512_mask_mov_ps(Result, PositiveInfinityMask, PositiveInfinity); - - const __mmask16 NegativeInfinityMask = _mm512_cmp_ps_mask(X, NegativeInfinity, _CMP_EQ_OQ); - Result = _mm512_mask_mov_ps(Result, NegativeInfinityMask, NegativeZero); - return Result; } diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp index a71d86fc38d5d..d7a4536408cee 100644 --- a/onnxruntime/core/mlas/lib/silu.cpp +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -14,6 +14,8 @@ Module Name: --*/ +#include + #include "mlasi.h" void @@ -26,8 +28,19 @@ MlasSiluKernel( { // This kernel is not buffer alias safe, as the computation is not elementwise. // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). 
+ const float PositiveInfinity = std::numeric_limits::infinity(); + const float NegativeInfinity = -std::numeric_limits::infinity(); + MlasComputeLogistic(Input, Output, N); MlasEltwiseMul(Input, Output, Output, N); + + for (size_t i = 0; i < N; ++i) { + if (Input[i] == PositiveInfinity) { + Output[i] = PositiveInfinity; + } else if (Input[i] == NegativeInfinity) { + Output[i] = -0.0f; + } + } } void diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 3366819ffeb17..5811a9280ace8 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -105,22 +105,6 @@ const std::vector& GetSiluSpecialValues() { return values; } -float ComputeReferenceSilu(float x) { - if (std::isnan(x)) { - return std::numeric_limits::quiet_NaN(); - } - - if (x == std::numeric_limits::infinity()) { - return x; - } - - if (x == -std::numeric_limits::infinity()) { - return -0.0f; - } - - return x / (1.0f + std::exp(-x)); -} - void FillInput(float* input, size_t n, float minimum_value, float maximum_value, const std::vector& special_values, uint32_t seed) { std::mt19937 generator(seed); @@ -200,6 +184,7 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { class MlasComputeSiluAvx512Test : public MlasTestBase { private: MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer generic_output_buffer_; MatrixGuardBuffer public_output_buffer_; MatrixGuardBuffer avx512_output_buffer_; @@ -211,32 +196,33 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { for (size_t size : sizes) { for (size_t iteration = 0; iteration < iterations; ++iteration) { float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); float* public_output = public_output_buffer_.GetBuffer(size); float* avx512_output = avx512_output_buffer_.GetBuffer(size); FillInput(input, size, kSiluMinValue, kSiluMaxValue, GetSiluSpecialValues(), static_cast(size * 149u + iteration * 991u + 31u)); + MlasSiluKernel(input, generic_output, size); MlasComputeSilu(input, public_output, size); MlasSiluKernelAvx512F(input, avx512_output, size); for (size_t i = 0; i < size; ++i) { - const float expected = ComputeReferenceSilu(input[i]); - ASSERT_TRUE(UnaryOutputsMatch(public_output[i], expected, + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) << "Public Silu mismatch at index " << i << " of " << size << ", input=" << input[i] << ", public=" << public_output[i] - << ", expected=" << expected - << ", abs_diff=" << std::fabs(public_output[i] - expected); + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); - ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], expected, + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) << "Silu mismatch at index " << i << " of " << size << ", input=" << input[i] << ", avx512=" << avx512_output[i] - << ", expected=" << expected - << ", abs_diff=" << std::fabs(avx512_output[i] - expected); + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) From 9a44721e5e02de5fc073c8725fa667a03871dd7b Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 
20 Mar 2026 21:03:57 -0700 Subject: [PATCH 23/37] a --- .../core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index def943b7c6aae..3ee36a4c6c0fc 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -114,11 +114,20 @@ MlasComputeSiluVectorAvx512( __m512 X ) { + const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); + const __m512 NegativeInfinity = _mm512_set1_ps(-std::numeric_limits::infinity()); + __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X)); const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); Result = _mm512_mask_mov_ps(Result, NaNMask, X); + const __mmask16 PositiveInfinityMask = _mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); + Result = _mm512_mask_mov_ps(Result, PositiveInfinityMask, PositiveInfinity); + + const __mmask16 NegativeInfinityMask = _mm512_cmp_ps_mask(X, NegativeInfinity, _CMP_EQ_OQ); + Result = _mm512_mask_mov_ps(Result, NegativeInfinityMask, NegativeInfinity); + return Result; } From 1f4f3a17c6c1c5b5cf76376e53bc2e5ee560f728 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 22:40:05 -0700 Subject: [PATCH 24/37] Revert accidental change --- onnxruntime/core/mlas/lib/silu.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp index d7a4536408cee..8733c45088ab1 100644 --- a/onnxruntime/core/mlas/lib/silu.cpp +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -28,19 +28,8 @@ MlasSiluKernel( { // This kernel is not buffer alias safe, as the computation is not elementwise. // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). 
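// Editor's sketch (not part of the patch series): calling the public SiLU entry point with
// distinct buffers, since in-place operation is unsupported per the aliasing notes in mlas.h.
// The vector names are illustrative only; assumes <vector> and mlas.h.
std::vector<float> silu_in(1024, 0.5f);
std::vector<float> silu_out(silu_in.size());
MlasComputeSilu(silu_in.data(), silu_out.data(), silu_in.size());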
- const float PositiveInfinity = std::numeric_limits::infinity(); - const float NegativeInfinity = -std::numeric_limits::infinity(); - MlasComputeLogistic(Input, Output, N); MlasEltwiseMul(Input, Output, Output, N); - - for (size_t i = 0; i < N; ++i) { - if (Input[i] == PositiveInfinity) { - Output[i] = PositiveInfinity; - } else if (Input[i] == NegativeInfinity) { - Output[i] = -0.0f; - } - } } void From 0f7db94be0dfff750288b2b54afe30344e1b1bd6 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 23:06:15 -0700 Subject: [PATCH 25/37] Update onnxruntime/core/mlas/lib/silu.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/silu.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp index 8733c45088ab1..a71d86fc38d5d 100644 --- a/onnxruntime/core/mlas/lib/silu.cpp +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -14,8 +14,6 @@ Module Name: --*/ -#include - #include "mlasi.h" void From 84cf795c076d76023810ad07e86b6b68e8c8ee57 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 20 Mar 2026 23:06:27 -0700 Subject: [PATCH 26/37] Update onnxruntime/core/mlas/lib/gelu.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/gelu.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index 6ddada8a327e8..d35d90208f244 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -14,8 +14,6 @@ Module Name: --*/ -#include - #include "mlasi.h" namespace { From 7c84be765543587e713ed4110f9875320b0f5414 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 22 Mar 2026 14:21:41 -0700 Subject: [PATCH 27/37] Copilot comment --- .../test/mlas/bench/bench_transcendental.cpp | 42 +++++++++++++++++-- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index 2a02a33214032..f7e1f99dbf655 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -7,6 +7,7 @@ #include "mlas.h" #include "bench_util.h" +#include "core/mlas/lib/mlasi.h" namespace { @@ -22,6 +23,35 @@ constexpr int64_t kFusedBytesPerElement = 2; constexpr int64_t kSiluUnfusedBytesPerElement = 5; constexpr int64_t kGeluUnfusedBytesPerElement = 7; +struct DispatchedUnaryPathInfo { + int64_t bytes_per_element; + const char* label; +}; + +DispatchedUnaryPathInfo GetSiluDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic path, + // which materializes the logistic result before the final multiply. + return {kSiluUnfusedBytesPerElement, "generic_fallback"}; +} + +DispatchedUnaryPathInfo GetGeluDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().GeluKernelRoutine == MlasGeluKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic exact + // GELU(erf) implementation, which uses separate scale, erf, and final passes. 
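  // Editor's note (an inference, not stated in the patch): the bytes_per_element figures
  // appear to count float loads/stores per element before the sizeof(float) scaling applied
  // in SetBytesProcessed(). A fused kernel touches each element twice (one load, one store),
  // while the unfused exact GELU uses three passes:
  //   scale pass:  read x, write tmp          -> 2 accesses
  //   erf pass:    read tmp, write tmp        -> 2 accesses
  //   finish pass: read x, read erf, write y  -> 3 accesses, i.e. 7 per element in total.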
+ return {kGeluUnfusedBytesPerElement, "generic_fallback"}; +} + std::vector MakeInput(size_t n, float min_value, float max_value) { auto data = RandomVectorUniform(n, min_value, max_value); @@ -45,11 +75,14 @@ template void RunDispatchedUnaryBenchmark(benchmark::State& state, KernelFn&& kernel, float min_value, - float max_value) { + float max_value, + DispatchedUnaryPathInfo path_info) { const auto n = static_cast(state.range(0)); auto input = MakeInput(n, min_value, max_value); std::vector output(n); + state.SetLabel(path_info.label); + kernel(input.data(), output.data(), n); for (auto _ : state) { @@ -59,7 +92,7 @@ void RunDispatchedUnaryBenchmark(benchmark::State& state, } state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); - state.SetBytesProcessed(static_cast(state.iterations()) * static_cast(n * sizeof(float) * kFusedBytesPerElement)); + state.SetBytesProcessed(static_cast(state.iterations()) * static_cast(n * sizeof(float) * path_info.bytes_per_element)); } template @@ -94,7 +127,7 @@ void BM_SiluDispatch(benchmark::State& state) { // Fused MLAS SiLU entry point. On supported platforms this may dispatch to a // specialized implementation that combines the activation into a single // kernel instead of exposing intermediate results. - RunDispatchedUnaryBenchmark(state, MlasComputeSilu, kSiluMinValue, kSiluMaxValue); + RunDispatchedUnaryBenchmark(state, MlasComputeSilu, kSiluMinValue, kSiluMaxValue, GetSiluDispatchPathInfo()); } void BM_SiluUnfusedDispatch(benchmark::State& state) { @@ -121,7 +154,8 @@ void BM_GeluErfDispatchExact(benchmark::State& state) { MlasComputeGeluErf(input, output, n); }, kGeluMinValue, - kGeluMaxValue); + kGeluMaxValue, + GetGeluDispatchPathInfo()); } void BM_GeluErfUnfusedExact(benchmark::State& state) { From c9a549cdfe5777fe874ada383baf52e54a58a5e1 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 22 Mar 2026 14:32:02 -0700 Subject: [PATCH 28/37] Copilot comment --- .../mlas/lib/intrinsics/avx512/silu_avx512f.cpp | 13 +++++++++++-- .../mlas/unittest/test_transcendental_avx512.cpp | 4 ++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 3ee36a4c6c0fc..bf255678702bd 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -28,6 +28,8 @@ struct SiluAvx512Constants { static constexpr float Half = 0.5f; static constexpr float One = 1.0f; static constexpr float Two = 2.0f; + static constexpr float LogisticLowerRange = -18.0f; + static constexpr float LogisticUpperRange = 18.0f; static constexpr float Ln2 = 0.693147182f; static constexpr float Log2EF = 1.44269502f; static constexpr float ExpLnFltMax = 88.3762589f; @@ -95,16 +97,23 @@ MlasLogisticApproxAvx512( { const __m512 One = _mm512_set1_ps(1.0f); const __m512 Zero = _mm512_setzero_ps(); + const __m512 LogisticLowerRange = _mm512_set1_ps(SiluAvx512Constants::LogisticLowerRange); + const __m512 LogisticUpperRange = _mm512_set1_ps(SiluAvx512Constants::LogisticUpperRange); const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::SignBitMask)); const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::PositiveMask)); - const __m512 XAbs = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(PositiveMask))); + // Mirror MlasComputeLogistic clamping so the AVX512 SiLU path 
matches the + // existing MLAS semantics for large finite inputs and avoids x * 0 invalid + // behavior when the exp approximation underflows. + const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, LogisticUpperRange), LogisticLowerRange); + + const __m512 XAbs = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(ClampedValue), _mm512_castps_si512(PositiveMask))); const __m512 XNeg = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(XAbs), _mm512_castps_si512(SignMask))); const __m512 E = MlasExpApproxAvx512(XNeg); const __m512 Y = _mm512_div_ps(E, _mm512_add_ps(E, One)); const __m512 OneMinusY = _mm512_sub_ps(One, Y); - const __mmask16 NegativeMask = _mm512_cmp_ps_mask(Value, Zero, _CMP_LT_OQ); + const __mmask16 NegativeMask = _mm512_cmp_ps_mask(ClampedValue, Zero, _CMP_LT_OQ); return _mm512_mask_blend_ps(NegativeMask, OneMinusY, Y); } diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 5811a9280ace8..35c989e49d7b1 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -86,6 +86,10 @@ const std::vector& GetSiluSpecialValues() { std::numeric_limits::quiet_NaN(), std::numeric_limits::infinity(), -std::numeric_limits::infinity(), + std::numeric_limits::max(), + -std::numeric_limits::max(), + 1.0e9f, + -1.0e9f, 0.0f, -0.0f, -20.0f, From 05ca9c2009b3cd6138ca902e95d6dcbe4863a5d1 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 22 Mar 2026 15:28:34 -0700 Subject: [PATCH 29/37] Rework Silu --- .../lib/intrinsics/avx512/silu_avx512f.cpp | 133 ++++++------------ 1 file changed, 46 insertions(+), 87 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index bf255678702bd..67c8224be74da 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -15,7 +15,6 @@ Module Name: --*/ -#include #include #include "mlasi.h" @@ -23,99 +22,65 @@ Module Name: namespace { struct SiluAvx512Constants { - static constexpr int32_t SignBitMask = INT32_MIN; - static constexpr int32_t PositiveMask = INT32_MAX; - static constexpr float Half = 0.5f; - static constexpr float One = 1.0f; - static constexpr float Two = 2.0f; static constexpr float LogisticLowerRange = -18.0f; static constexpr float LogisticUpperRange = 18.0f; - static constexpr float Ln2 = 0.693147182f; - static constexpr float Log2EF = 1.44269502f; - static constexpr float ExpLnFltMax = 88.3762589f; - static constexpr float ExpLnFltMin = -87.3365479f; - - static constexpr float P1 = 0.999999701f; - static constexpr float P2 = 0.499991506f; - static constexpr float P3 = 0.166676521f; - static constexpr float P4 = 0.0418978221f; - static constexpr float P5 = 0.00828929059f; + static constexpr float Alpha9 = 4.37031012579801e-11f; + static constexpr float Alpha7 = 1.15627324459942e-07f; + static constexpr float Alpha5 = 6.08574864600143e-05f; + static constexpr float Alpha3 = 8.51377133304701e-03f; + static constexpr float Alpha1 = 2.48287947061529e-01f; + static constexpr float Beta10 = 6.10247389755681e-13f; + static constexpr float Beta8 = 5.76102136993427e-09f; + static constexpr float Beta6 = 6.29106785017040e-06f; + static constexpr float Beta4 = 1.70198817374094e-03f; + static constexpr float Beta2 = 1.16817656904453e-01f; + static constexpr float Beta0 = 
9.93151921023180e-01f; + static constexpr float OneHalf = 0.5f; }; -MLAS_FORCEINLINE __m512 -MlasExpApproxAvx512( - __m512 Value - ) -{ - const __m512 Half = _mm512_set1_ps(SiluAvx512Constants::Half); - const __m512 One = _mm512_set1_ps(SiluAvx512Constants::One); - const __m512 Two = _mm512_set1_ps(SiluAvx512Constants::Two); - const __m512 Ln2 = _mm512_set1_ps(SiluAvx512Constants::Ln2); - const __m512 Log2EF = _mm512_set1_ps(SiluAvx512Constants::Log2EF); - const __m512 ExpLnFltMax = _mm512_set1_ps(SiluAvx512Constants::ExpLnFltMax); - const __m512 ExpLnFltMin = _mm512_set1_ps(SiluAvx512Constants::ExpLnFltMin); - const __m512 P1 = _mm512_set1_ps(SiluAvx512Constants::P1); - const __m512 P2 = _mm512_set1_ps(SiluAvx512Constants::P2); - const __m512 P3 = _mm512_set1_ps(SiluAvx512Constants::P3); - const __m512 P4 = _mm512_set1_ps(SiluAvx512Constants::P4); - const __m512 P5 = _mm512_set1_ps(SiluAvx512Constants::P5); - const __m512i ExponentBias = _mm512_set1_epi32(127); - - const __mmask16 UnderflowMask = _mm512_cmp_ps_mask(Value, ExpLnFltMin, _CMP_LT_OQ); - - Value = _mm512_min_ps(Value, ExpLnFltMax); - Value = _mm512_max_ps(Value, ExpLnFltMin); - - __m512 Fx = _mm512_fmadd_ps(Value, Log2EF, Half); - Fx = _mm512_floor_ps(Fx); - - const __m512 R = _mm512_fnmadd_ps(Fx, Ln2, Value); - - const __m512 NMinusOne = _mm512_sub_ps(Fx, One); - __m512i Exponent = _mm512_cvttps_epi32(NMinusOne); - Exponent = _mm512_add_epi32(Exponent, ExponentBias); - Exponent = _mm512_slli_epi32(Exponent, 23); - Exponent = _mm512_mask_mov_epi32(Exponent, UnderflowMask, _mm512_setzero_si512()); - const __m512 Pow2NMinusOne = _mm512_castsi512_ps(Exponent); - - __m512 Y = P5; - Y = _mm512_fmadd_ps(Y, R, P4); - Y = _mm512_fmadd_ps(Y, R, P3); - Y = _mm512_fmadd_ps(Y, R, P2); - Y = _mm512_fmadd_ps(Y, R, P1); - Y = _mm512_fmadd_ps(Y, R, One); - - Y = _mm512_mul_ps(Y, Pow2NMinusOne); - Y = _mm512_mul_ps(Y, Two); - return Y; -} - MLAS_FORCEINLINE __m512 MlasLogisticApproxAvx512( __m512 Value ) { - const __m512 One = _mm512_set1_ps(1.0f); - const __m512 Zero = _mm512_setzero_ps(); const __m512 LogisticLowerRange = _mm512_set1_ps(SiluAvx512Constants::LogisticLowerRange); const __m512 LogisticUpperRange = _mm512_set1_ps(SiluAvx512Constants::LogisticUpperRange); - const __m512 SignMask = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::SignBitMask)); - const __m512 PositiveMask = _mm512_castsi512_ps(_mm512_set1_epi32(SiluAvx512Constants::PositiveMask)); + const __m512 Alpha9 = _mm512_set1_ps(SiluAvx512Constants::Alpha9); + const __m512 Alpha7 = _mm512_set1_ps(SiluAvx512Constants::Alpha7); + const __m512 Alpha5 = _mm512_set1_ps(SiluAvx512Constants::Alpha5); + const __m512 Alpha3 = _mm512_set1_ps(SiluAvx512Constants::Alpha3); + const __m512 Alpha1 = _mm512_set1_ps(SiluAvx512Constants::Alpha1); + const __m512 Beta10 = _mm512_set1_ps(SiluAvx512Constants::Beta10); + const __m512 Beta8 = _mm512_set1_ps(SiluAvx512Constants::Beta8); + const __m512 Beta6 = _mm512_set1_ps(SiluAvx512Constants::Beta6); + const __m512 Beta4 = _mm512_set1_ps(SiluAvx512Constants::Beta4); + const __m512 Beta2 = _mm512_set1_ps(SiluAvx512Constants::Beta2); + const __m512 Beta0 = _mm512_set1_ps(SiluAvx512Constants::Beta0); + const __m512 OneHalf = _mm512_set1_ps(SiluAvx512Constants::OneHalf); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 One = _mm512_set1_ps(1.0f); - // Mirror MlasComputeLogistic clamping so the AVX512 SiLU path matches the - // existing MLAS semantics for large finite inputs and avoids x * 0 invalid - // behavior when the exp 
approximation underflows. + // Mirror MlasComputeLogistic by evaluating the same clamped rational + // approximation in-register and then multiplying by x for SiLU. const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, LogisticUpperRange), LogisticLowerRange); + const __m512 ValueSquared = _mm512_mul_ps(ClampedValue, ClampedValue); + + __m512 P = _mm512_fmadd_ps(ValueSquared, Alpha9, Alpha7); + P = _mm512_fmadd_ps(P, ValueSquared, Alpha5); + P = _mm512_fmadd_ps(P, ValueSquared, Alpha3); + P = _mm512_fmadd_ps(P, ValueSquared, Alpha1); + P = _mm512_mul_ps(P, ClampedValue); - const __m512 XAbs = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(ClampedValue), _mm512_castps_si512(PositiveMask))); - const __m512 XNeg = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(XAbs), _mm512_castps_si512(SignMask))); + __m512 Q = _mm512_fmadd_ps(ValueSquared, Beta10, Beta8); + Q = _mm512_fmadd_ps(Q, ValueSquared, Beta6); + Q = _mm512_fmadd_ps(Q, ValueSquared, Beta4); + Q = _mm512_fmadd_ps(Q, ValueSquared, Beta2); + Q = _mm512_fmadd_ps(Q, ValueSquared, Beta0); - const __m512 E = MlasExpApproxAvx512(XNeg); - const __m512 Y = _mm512_div_ps(E, _mm512_add_ps(E, One)); - const __m512 OneMinusY = _mm512_sub_ps(One, Y); - const __mmask16 NegativeMask = _mm512_cmp_ps_mask(ClampedValue, Zero, _CMP_LT_OQ); + __m512 Logistic = _mm512_add_ps(_mm512_div_ps(P, Q), OneHalf); + Logistic = _mm512_min_ps(_mm512_max_ps(Logistic, Zero), One); - return _mm512_mask_blend_ps(NegativeMask, OneMinusY, Y); + return Logistic; } MLAS_FORCEINLINE __m512 @@ -123,20 +88,14 @@ MlasComputeSiluVectorAvx512( __m512 X ) { - const __m512 PositiveInfinity = _mm512_set1_ps(std::numeric_limits::infinity()); - const __m512 NegativeInfinity = _mm512_set1_ps(-std::numeric_limits::infinity()); - __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X)); + // Preserve NaN payload/sign behavior explicitly because the clamped + // logistic approximation uses min/max operations that do not reliably + // propagate NaNs the same way as the existing MLAS SiLU semantics. 
const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); Result = _mm512_mask_mov_ps(Result, NaNMask, X); - const __mmask16 PositiveInfinityMask = _mm512_cmp_ps_mask(X, PositiveInfinity, _CMP_EQ_OQ); - Result = _mm512_mask_mov_ps(Result, PositiveInfinityMask, PositiveInfinity); - - const __mmask16 NegativeInfinityMask = _mm512_cmp_ps_mask(X, NegativeInfinity, _CMP_EQ_OQ); - Result = _mm512_mask_mov_ps(Result, NegativeInfinityMask, NegativeInfinity); - return Result; } From 5f844baf697db38c279b2cdcb5929e8ce211db1b Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 22 Mar 2026 19:04:14 -0700 Subject: [PATCH 30/37] Experiment --- .../lib/intrinsics/avx512/silu_avx512f.cpp | 57 ++++++++++++------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 67c8224be74da..670142d88727d 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -38,11 +38,7 @@ struct SiluAvx512Constants { static constexpr float OneHalf = 0.5f; }; -MLAS_FORCEINLINE __m512 -MlasLogisticApproxAvx512( - __m512 Value - ) -{ +struct SiluAvx512BroadcastConstants { const __m512 LogisticLowerRange = _mm512_set1_ps(SiluAvx512Constants::LogisticLowerRange); const __m512 LogisticUpperRange = _mm512_set1_ps(SiluAvx512Constants::LogisticUpperRange); const __m512 Alpha9 = _mm512_set1_ps(SiluAvx512Constants::Alpha9); @@ -59,36 +55,44 @@ MlasLogisticApproxAvx512( const __m512 OneHalf = _mm512_set1_ps(SiluAvx512Constants::OneHalf); const __m512 Zero = _mm512_setzero_ps(); const __m512 One = _mm512_set1_ps(1.0f); +}; +MLAS_FORCEINLINE __m512 +MlasLogisticApproxAvx512( + __m512 Value, + const SiluAvx512BroadcastConstants& Constants + ) +{ // Mirror MlasComputeLogistic by evaluating the same clamped rational // approximation in-register and then multiplying by x for SiLU. 
- const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, LogisticUpperRange), LogisticLowerRange); + const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, Constants.LogisticUpperRange), Constants.LogisticLowerRange); const __m512 ValueSquared = _mm512_mul_ps(ClampedValue, ClampedValue); - __m512 P = _mm512_fmadd_ps(ValueSquared, Alpha9, Alpha7); - P = _mm512_fmadd_ps(P, ValueSquared, Alpha5); - P = _mm512_fmadd_ps(P, ValueSquared, Alpha3); - P = _mm512_fmadd_ps(P, ValueSquared, Alpha1); + __m512 P = _mm512_fmadd_ps(ValueSquared, Constants.Alpha9, Constants.Alpha7); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha5); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha3); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha1); P = _mm512_mul_ps(P, ClampedValue); - __m512 Q = _mm512_fmadd_ps(ValueSquared, Beta10, Beta8); - Q = _mm512_fmadd_ps(Q, ValueSquared, Beta6); - Q = _mm512_fmadd_ps(Q, ValueSquared, Beta4); - Q = _mm512_fmadd_ps(Q, ValueSquared, Beta2); - Q = _mm512_fmadd_ps(Q, ValueSquared, Beta0); + __m512 Q = _mm512_fmadd_ps(ValueSquared, Constants.Beta10, Constants.Beta8); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta6); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta4); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta2); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta0); - __m512 Logistic = _mm512_add_ps(_mm512_div_ps(P, Q), OneHalf); - Logistic = _mm512_min_ps(_mm512_max_ps(Logistic, Zero), One); + __m512 Logistic = _mm512_add_ps(_mm512_div_ps(P, Q), Constants.OneHalf); + Logistic = _mm512_min_ps(_mm512_max_ps(Logistic, Constants.Zero), Constants.One); return Logistic; } MLAS_FORCEINLINE __m512 MlasComputeSiluVectorAvx512( - __m512 X + __m512 X, + const SiluAvx512BroadcastConstants& Constants ) { - __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X)); + __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X, Constants)); // Preserve NaN payload/sign behavior explicitly because the clamped // logistic approximation uses min/max operations that do not reliably @@ -109,11 +113,22 @@ MlasSiluKernelAvx512F( size_t N ) { + const SiluAvx512BroadcastConstants Constants; size_t Offset = 0; + while (Offset + 32 <= N) { + const __m512 X0 = _mm512_loadu_ps(Input + Offset); + const __m512 X1 = _mm512_loadu_ps(Input + Offset + 16); + const __m512 Result0 = MlasComputeSiluVectorAvx512(X0, Constants); + const __m512 Result1 = MlasComputeSiluVectorAvx512(X1, Constants); + _mm512_storeu_ps(Output + Offset, Result0); + _mm512_storeu_ps(Output + Offset + 16, Result1); + Offset += 32; + } + while (Offset + 16 <= N) { const __m512 X = _mm512_loadu_ps(Input + Offset); - const __m512 Result = MlasComputeSiluVectorAvx512(X); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); _mm512_storeu_ps(Output + Offset, Result); Offset += 16; } @@ -121,7 +136,7 @@ MlasSiluKernelAvx512F( if (Offset < N) { const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); - const __m512 Result = MlasComputeSiluVectorAvx512(X); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); } } From bf41c638a1ba5a612cc9f112e6b2ed858be0e1bd Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 22 Mar 2026 19:38:10 -0700 Subject: [PATCH 31/37] Experiment with Exact Gelu --- .../lib/intrinsics/avx512/gelu_avx512f.cpp | 143 +++++++++--------- 1 file changed, 74 insertions(+), 69 deletions(-) diff --git 
a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index 358b882b883be..4770e6c12a962 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -57,11 +57,29 @@ struct GeluAvx512Constants { static constexpr float ExpC = 1.25829120e+7f; }; -MLAS_FORCEINLINE __m512 -MlasGeluErfExpVectorAvx512( - __m512 Value - ) -{ +struct GeluAvx512BroadcastConstants { + const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask)); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 InvSqrt2 = _mm512_set1_ps(GeluAvx512Constants::InvSqrt2); + const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half); + const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One); + const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); + const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); + const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0); + const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1); + const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2); + const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3); + const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4); + const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One); + const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0); + const __m512 ErfBigP1 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1); + const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2); + const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3); + const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4); + const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5); + const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One); + const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); + const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); const __m512 ExpLog2Reciprocal = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Reciprocal); const __m512 ExpLog2Hi = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Hi); const __m512 ExpLog2Lo = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Lo); @@ -73,20 +91,27 @@ MlasGeluErfExpVectorAvx512( const __m512 ExpP5 = _mm512_set1_ps(GeluAvx512Constants::ExpP5); const __m512 ExpP6 = _mm512_set1_ps(GeluAvx512Constants::ExpP6); const __m512 ExpC = _mm512_set1_ps(GeluAvx512Constants::ExpC); +}; - __m512 R = _mm512_fmadd_ps(ExpLog2Reciprocal, Value, ExpC); - R = _mm512_sub_ps(R, ExpC); - - __m512 Fx = _mm512_fmadd_ps(R, ExpLog2Hi, Value); - Fx = _mm512_fmadd_ps(R, ExpLog2Lo, Fx); - - __m512 Y = ExpP0; - Y = _mm512_fmadd_ps(Y, Fx, ExpP1); - Y = _mm512_fmadd_ps(Y, Fx, ExpP2); - Y = _mm512_fmadd_ps(Y, Fx, ExpP3); - Y = _mm512_fmadd_ps(Y, Fx, ExpP4); - Y = _mm512_fmadd_ps(Y, Fx, ExpP5); - Y = _mm512_fmadd_ps(Y, Fx, ExpP6); +MLAS_FORCEINLINE __m512 +MlasGeluErfExpVectorAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + __m512 R = _mm512_fmadd_ps(Constants.ExpLog2Reciprocal, Value, Constants.ExpC); + R = _mm512_sub_ps(R, Constants.ExpC); + + __m512 Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Hi, Value); + Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Lo, Fx); + + __m512 Y = Constants.ExpP0; + Y = _mm512_fmadd_ps(Y, Fx, 
Constants.ExpP1); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP2); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP3); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP4); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP5); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP6); Y = _mm512_scalef_ps(Y, R); return Y; @@ -94,58 +119,39 @@ MlasGeluErfExpVectorAvx512( MLAS_FORCEINLINE __m512 MlasGeluErfAvx512( - __m512 Value + __m512 Value, + const GeluAvx512BroadcastConstants& Constants ) { - const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask)); - const __m512 Zero = _mm512_setzero_ps(); - const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); - const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); - const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0); - const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1); - const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2); - const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3); - const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4); - const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One); - const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0); - const __m512 ErfBigP1 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1); - const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2); - const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3); - const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4); - const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5); - const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One); - const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); - const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); - - const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(NegZero))); - __m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(NegZero), _mm512_castps_si512(Value))); - AbsValue = _mm512_min_ps(ErfUpperAbsRange, AbsValue); + const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(Constants.NegZero))); + __m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(Constants.NegZero), _mm512_castps_si512(Value))); + AbsValue = _mm512_min_ps(Constants.ErfUpperAbsRange, AbsValue); const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue); - __m512 SmallResult = ErfSmallP0; - SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP1); - SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP2); - SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP3); - SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP4); - SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, ErfSmallP5MinusOne); + __m512 SmallResult = Constants.ErfSmallP0; + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP1); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP2); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP3); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP4); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP5MinusOne); SmallResult = _mm512_fmadd_ps(SmallResult, 
AbsValue, AbsValue); - const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, ErfSplitBoundary, _CMP_GT_OQ); - const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Zero, AbsValue); + const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, Constants.ErfSplitBoundary, _CMP_GT_OQ); + const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Constants.Zero, AbsValue); - __m512 BigResult = ErfBigP0; - BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP1); - BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP2); - BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP3); - BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP4); - BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP5); - BigResult = _mm512_fmadd_ps(BigResult, BigInput, ErfBigP6MinusOne); + __m512 BigResult = Constants.ErfBigP0; + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP1); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP2); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP3); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP4); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP5); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP6MinusOne); BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput); - BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(NegZero))); - BigResult = _mm512_max_ps(ExpLowerRange, BigResult); - BigResult = _mm512_sub_ps(ErfOne, MlasGeluErfExpVectorAvx512(BigResult)); + BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(Constants.NegZero))); + BigResult = _mm512_max_ps(Constants.ExpLowerRange, BigResult); + BigResult = _mm512_sub_ps(Constants.ErfOne, MlasGeluErfExpVectorAvx512(BigResult, Constants)); __m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult); Result = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(Result), _mm512_castps_si512(SignMask))); @@ -154,15 +160,13 @@ MlasGeluErfAvx512( MLAS_FORCEINLINE __m512 MlasComputeGeluVectorExactAvx512( - __m512 X + __m512 X, + const GeluAvx512BroadcastConstants& Constants ) { - const __m512 InvSqrt2 = _mm512_set1_ps(GeluAvx512Constants::InvSqrt2); - const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half); - const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One); - const __m512 ErfInput = _mm512_mul_ps(X, InvSqrt2); - const __m512 ErfValue = MlasGeluErfAvx512(ErfInput); - return _mm512_mul_ps(_mm512_mul_ps(Half, X), _mm512_add_ps(ErfValue, One)); + const __m512 ErfInput = _mm512_mul_ps(X, Constants.InvSqrt2); + const __m512 ErfValue = MlasGeluErfAvx512(ErfInput, Constants); + return _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One)); } void @@ -172,9 +176,10 @@ MlasGeluKernelAvx512FExactImpl( size_t N ) { + const GeluAvx512BroadcastConstants Constants; while (N >= 16) { const __m512 X = _mm512_loadu_ps(Input); - const __m512 Result = MlasComputeGeluVectorExactAvx512(X); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants); _mm512_storeu_ps(Output, Result); @@ -186,7 +191,7 @@ MlasGeluKernelAvx512FExactImpl( if (N > 0) { const __mmask16 TailMask = __mmask16((1u << static_cast(N)) - 1u); const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input); - const __m512 Result = MlasComputeGeluVectorExactAvx512(X); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants); _mm512_mask_storeu_ps(Output, TailMask, Result); } From 
28bf32e28b4408fa7d918dffaaf96ffac6ce1c81 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 22 Mar 2026 21:23:56 -0700 Subject: [PATCH 32/37] Copilot comments --- onnxruntime/test/mlas/bench/bench_transcendental.cpp | 6 ++++-- .../test/mlas/unittest/test_transcendental_avx512.cpp | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index f7e1f99dbf655..cc112ccb1498e 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -91,8 +91,9 @@ void RunDispatchedUnaryBenchmark(benchmark::State& state, benchmark::ClobberMemory(); } + const int64_t bytes_per_iteration = static_cast<int64_t>(n) * static_cast<int64_t>(sizeof(float)) * path_info.bytes_per_element; state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * static_cast<int64_t>(n)); - state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * static_cast<int64_t>(n * sizeof(float) * path_info.bytes_per_element)); + state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * bytes_per_iteration); } template <typename KernelFn> @@ -113,8 +114,9 @@ void RunUnfusedUnaryBenchmark(benchmark::State& state, benchmark::ClobberMemory(); } + const int64_t bytes_per_iteration = static_cast<int64_t>(n) * static_cast<int64_t>(sizeof(float)) * bytes_per_element; state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * static_cast<int64_t>(n)); - state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * static_cast<int64_t>(n * sizeof(float) * bytes_per_element)); + state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * bytes_per_iteration); } static void UnaryKernelArgs(benchmark::internal::Benchmark* b) { diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 35c989e49d7b1..052f6cedbc0e8 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -166,6 +166,14 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { << ", avx512=" << avx512_output[i] << ", generic=" << generic_output[i] << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public/API Gelu dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); } } } From 947652209114496d1d80db925490e2901cbe219d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 23 Mar 2026 12:35:41 -0700 Subject: [PATCH 33/37] PR feedback --- onnxruntime/core/mlas/lib/gelu.cpp | 14 +++++++++----- .../mlas/lib/intrinsics/avx512/gelu_avx512f.cpp | 16 ++++++++++++---- onnxruntime/core/mlas/lib/mlasi.h | 6 +++--- onnxruntime/core/mlas/lib/platform.cpp | 4 ++-- onnxruntime/core/mlas/lib/silu.cpp | 9 +++++++-- onnxruntime/core/providers/cpu/tensor/gelu.cc | 2 ++ .../test/mlas/bench/bench_transcendental.cpp | 6 +++--- .../mlas/unittest/test_transcendental_avx512.cpp | 16 ++++++++-------- 8 files changed, 46 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp index d35d90208f244..dc25611652c77 100644 --- a/onnxruntime/core/mlas/lib/gelu.cpp +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -25,15 +25,16 @@ constexpr float kInvSqrt2 = 0.70710678118654752440f; void 
MLASCALL -MlasGeluKernel( +MlasGeluErfKernel( const float* Input, float* Output, size_t N ) { - // This kernel is not buffer alias safe, as the computation is not elementwise. + // This kernel is not buffer alias safe because it is implemented in + // multiple passes: first scale Input into Output, then apply erf in place, + // and finally combine that intermediate with the original Input values. // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). - // for (size_t i = 0; i < N; ++i) { Output[i] = Input[i] * kInvSqrt2; } @@ -54,8 +55,11 @@ MlasComputeGeluErf( ) { #if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().GeluKernelRoutine(Input, Output, N); + // TODO: Add an intermediate fused AVX2/FMA3 GELU(erf) path on AMD64. + // Today the dispatch jumps from the generic multi-pass implementation to + // AVX512F, so non-AVX512 x64 machines fall back to the generic kernel. + GetMlasPlatform().GeluErfKernelRoutine(Input, Output, N); #else - MlasGeluKernel(Input, Output, N); + MlasGeluErfKernel(Input, Output, N); #endif } diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp index 4770e6c12a962..4a9f3a100ed65 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -166,11 +166,19 @@ MlasComputeGeluVectorExactAvx512( { const __m512 ErfInput = _mm512_mul_ps(X, Constants.InvSqrt2); const __m512 ErfValue = MlasGeluErfAvx512(ErfInput, Constants); - return _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One)); + __m512 Result = _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One)); + + // Preserve NaN payload/sign behavior explicitly because the erf + // approximation uses min/max style range limiting that is not guaranteed to + // preserve NaNs the same way as the existing MLAS GELU semantics. 
+ const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + return Result; } void -MlasGeluKernelAvx512FExactImpl( +MlasGeluErfKernelAvx512FExactImpl( const float* Input, float* Output, size_t N @@ -201,11 +209,11 @@ MlasGeluKernelAvx512FExactImpl( void MLASCALL -MlasGeluKernelAvx512F( +MlasGeluErfKernelAvx512F( const float* Input, float* Output, size_t N ) { - MlasGeluKernelAvx512FExactImpl(Input, Output, N); + MlasGeluErfKernelAvx512FExactImpl(Input, Output, N); } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index d860bb76e1a41..b667ca78b37d7 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1096,7 +1096,7 @@ extern "C" { #endif MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernel; @@ -1128,7 +1128,7 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluKernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernelAvx512F; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelAvx512F; #endif @@ -1481,7 +1481,7 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif #if defined(MLAS_TARGET_AMD64) - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index a9a52a7ff75aa..4b83309fd3347 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -283,7 +283,7 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; - this->GeluKernelRoutine = MlasGeluKernel; + this->GeluErfKernelRoutine = MlasGeluErfKernel; this->LogisticKernelRoutine = MlasLogisticKernel; this->SiluKernelRoutine = MlasSiluKernel; this->TanhKernelRoutine = MlasTanhKernel; @@ -461,7 +461,7 @@ Return Value: // if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { - this->GeluKernelRoutine = MlasGeluKernelAvx512F; + this->GeluErfKernelRoutine = MlasGeluErfKernelAvx512F; this->SiluKernelRoutine = MlasSiluKernelAvx512F; this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp index a71d86fc38d5d..96686e4bdf1da 100644 --- a/onnxruntime/core/mlas/lib/silu.cpp +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -24,8 +24,10 @@ MlasSiluKernel( size_t N ) { - // This kernel is not buffer alias safe, as the computation is not elementwise. - // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). 
+ // This kernel is not buffer alias safe because it is implemented in two + // passes: first compute logistic(Input) into Output, then multiply that + // intermediate by the original Input values. Callers must guarantee that + // Input and Output do not overlap (see mlas.h for aliasing requirements). MlasComputeLogistic(Input, Output, N); MlasEltwiseMul(Input, Output, Output, N); } @@ -39,6 +41,9 @@ MlasComputeSilu( ) { #if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 SiLU path on AMD64. Today the + // dispatch jumps from the generic two-pass implementation to AVX512F, so + // non-AVX512 x64 machines fall back to the generic kernel. GetMlasPlatform().SiluKernelRoutine(Input, Output, N); #else MlasSiluKernel(Input, Output, N); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index 28bb63561fac6..e34af83d1f29e 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -88,6 +88,8 @@ Status Gelu::Compute(OpKernelContext* context) const { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); + // MlasComputeGeluErf requires distinct input/output buffers. This + // call uses disjoint slices from the input and output tensors. MlasComputeGeluErf(p_input, p_output, narrow(count)); }, 0); diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index cc112ccb1498e..812c97583cb73 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -40,9 +40,9 @@ DispatchedUnaryPathInfo GetSiluDispatchPathInfo() { return {kSiluUnfusedBytesPerElement, "generic_fallback"}; } -DispatchedUnaryPathInfo GetGeluDispatchPathInfo() { +DispatchedUnaryPathInfo GetGeluErfDispatchPathInfo() { #if defined(MLAS_TARGET_AMD64) - if (GetMlasPlatform().GeluKernelRoutine == MlasGeluKernelAvx512F) { + if (GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F) { return {kFusedBytesPerElement, "avx512_fused"}; } #endif @@ -157,7 +157,7 @@ void BM_GeluErfDispatchExact(benchmark::State& state) { }, kGeluMinValue, kGeluMaxValue, - GetGeluDispatchPathInfo()); + GetGeluErfDispatchPathInfo()); } void BM_GeluErfUnfusedExact(benchmark::State& state) { diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 052f6cedbc0e8..64a99674eebed 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -28,8 +28,8 @@ constexpr std::array kLongTestSizes = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; -bool IsGeluAvx512Dispatched() { - return GetMlasPlatform().GeluKernelRoutine == MlasGeluKernelAvx512F; +bool IsGeluErfAvx512Dispatched() { + return GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F; } bool IsSiluAvx512Dispatched() { @@ -124,7 +124,7 @@ void FillInput(float* input, size_t n, float minimum_value, float maximum_value, } } -class MlasComputeGeluAvx512Test : public MlasTestBase { +class MlasComputeGeluErfAvx512Test : public MlasTestBase { private: MatrixGuardBuffer input_buffer_; MatrixGuardBuffer generic_output_buffer_; @@ -132,7 +132,7 @@ class MlasComputeGeluAvx512Test : public MlasTestBase { MatrixGuardBuffer avx512_output_buffer_; void 
ExecuteCommon(const std::vector<size_t>& sizes, size_t iterations) { - if (!IsGeluAvx512Dispatched()) { + if (!IsGeluErfAvx512Dispatched()) { GTEST_SKIP() << "AVX512F GELU dispatch is not available on this machine."; } @@ -146,9 +146,9 @@ class MlasComputeGeluErfAvx512Test : public MlasTestBase { FillInput(input, size, kGeluMinValue, kGeluMaxValue, GetGeluSpecialValues(), static_cast(size * 131u + iteration * 977u + 17u)); - MlasGeluKernel(input, generic_output, size); + MlasGeluErfKernel(input, generic_output, size); MlasComputeGeluErf(input, public_output, size); - MlasGeluKernelAvx512F(input, avx512_output, size); + MlasGeluErfKernelAvx512F(input, avx512_output, size); for (size_t i = 0; i < size; ++i) { ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], @@ -267,10 +267,10 @@ class MlasComputeSiluAvx512Test : public MlasTestBase { static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { - count += MlasDirectShortExecuteTests<MlasComputeGeluAvx512Test>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasComputeGeluErfAvx512Test>::RegisterShortExecute(); count += MlasDirectShortExecuteTests<MlasComputeSiluAvx512Test>::RegisterShortExecute(); } else { - count += MlasLongExecuteTests<MlasComputeGeluAvx512Test>::RegisterLongExecute(); + count += MlasLongExecuteTests<MlasComputeGeluErfAvx512Test>::RegisterLongExecute(); count += MlasLongExecuteTests<MlasComputeSiluAvx512Test>::RegisterLongExecute(); } return count; From cfa0bf233450974b57a4f43d9b1a27d90ac209aa Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 23 Mar 2026 12:36:46 -0700 Subject: [PATCH 34/37] Nit --- onnxruntime/test/mlas/bench/bench_transcendental.cpp | 2 +- .../test/mlas/unittest/test_transcendental_avx512.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index 812c97583cb73..f7e461c29843a 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -161,7 +161,7 @@ void BM_GeluErfDispatchExact(benchmark::State& state) { } void BM_GeluErfUnfusedExact(benchmark::State& state) { - // Unfused exact GELU baseline: scale by 1/sqrt(2), run erf, then apply the + // Unfused exact GELU(erf) baseline: scale by 1/sqrt(2), run erf, then apply the // final 0.5 * x * (erf(x / sqrt(2)) + 1) transform in a separate pass. 
RunUnfusedUnaryBenchmark( state, diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp index 64a99674eebed..e87768ce3e660 100644 --- a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -133,7 +133,7 @@ class MlasComputeGeluErfAvx512Test : public MlasTestBase { void ExecuteCommon(const std::vector& sizes, size_t iterations) { if (!IsGeluErfAvx512Dispatched()) { - GTEST_SKIP() << "AVX512F GELU dispatch is not available on this machine."; + GTEST_SKIP() << "AVX512F GELU(erf) dispatch is not available on this machine."; } for (size_t size : sizes) { @@ -153,7 +153,7 @@ class MlasComputeGeluErfAvx512Test : public MlasTestBase { for (size_t i = 0; i < size; ++i) { ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) - << "Public Gelu mismatch at index " << i << " of " << size + << "Public GELU(erf) mismatch at index " << i << " of " << size << ", input=" << input[i] << ", public=" << public_output[i] << ", generic=" << generic_output[i] @@ -161,7 +161,7 @@ class MlasComputeGeluErfAvx512Test : public MlasTestBase { ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) - << "Gelu mismatch at index " << i << " of " << size + << "GELU(erf) mismatch at index " << i << " of " << size << ", input=" << input[i] << ", avx512=" << avx512_output[i] << ", generic=" << generic_output[i] @@ -169,7 +169,7 @@ class MlasComputeGeluErfAvx512Test : public MlasTestBase { ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) - << "Public/API Gelu dispatch mismatch at index " << i << " of " << size + << "Public/API GELU(erf) dispatch mismatch at index " << i << " of " << size << ", input=" << input[i] << ", avx512=" << avx512_output[i] << ", public=" << public_output[i] From d459cf72e2cb78cb1e9c6ff83af1612f20e68cee Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 23 Mar 2026 13:38:16 -0700 Subject: [PATCH 35/37] Fix alignment --- onnxruntime/core/mlas/lib/platform.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4b83309fd3347..eccde79848e61 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -283,7 +283,7 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; - this->GeluErfKernelRoutine = MlasGeluErfKernel; + this->GeluErfKernelRoutine = MlasGeluErfKernel; this->LogisticKernelRoutine = MlasLogisticKernel; this->SiluKernelRoutine = MlasSiluKernel; this->TanhKernelRoutine = MlasTanhKernel; From 8652b3d6e8352863b00e27a0b4d9d046cf20f3c6 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 23 Mar 2026 13:46:23 -0700 Subject: [PATCH 36/37] Update onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp 
b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp index 670142d88727d..7e8424d94827a 100644 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -15,8 +15,6 @@ Module Name: --*/ -#include - #include "mlasi.h" namespace { From fcf1f1a39d8a3347839eec85c4ba411e51c03eaa Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 23 Mar 2026 18:58:18 -0700 Subject: [PATCH 37/37] Alignment --- onnxruntime/core/mlas/lib/mlasi.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b667ca78b37d7..0dab8e41f25cd 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1481,7 +1481,7 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif #if defined(MLAS_TARGET_AMD64) - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine;
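
Reviewer note: the following is a minimal scalar sketch, added for illustration only, of the two activations these patches vectorize. It assumes nothing beyond constants that already appear in the diffs (the +/-18 logistic clamp, the Alpha*/Beta* rational coefficients in SiluAvx512Constants, and kInvSqrt2 in gelu.cpp); the names ReferenceSilu and ReferenceGeluErf are hypothetical and do not exist in MLAS, and the GELU path calls std::erf instead of the kernel's polynomial erf approximation, so small differences from the AVX512 output are expected.

// Scalar reference sketch (illustration only, not part of the patches above).
#include <algorithm>
#include <cmath>

float ReferenceSilu(float x) {
    // Clamp to [-18, 18] like MlasComputeLogistic before evaluating the
    // rational approximation; this sketch does not reproduce the kernel's
    // explicit NaN passthrough.
    const float v = std::max(std::min(x, 18.0f), -18.0f);
    const float v2 = v * v;

    float p = 4.37031012579801e-11f;     // Alpha9
    p = p * v2 + 1.15627324459942e-07f;  // Alpha7
    p = p * v2 + 6.08574864600143e-05f;  // Alpha5
    p = p * v2 + 8.51377133304701e-03f;  // Alpha3
    p = p * v2 + 2.48287947061529e-01f;  // Alpha1
    p = p * v;

    float q = 6.10247389755681e-13f;     // Beta10
    q = q * v2 + 5.76102136993427e-09f;  // Beta8
    q = q * v2 + 6.29106785017040e-06f;  // Beta6
    q = q * v2 + 1.70198817374094e-03f;  // Beta4
    q = q * v2 + 1.16817656904453e-01f;  // Beta2
    q = q * v2 + 9.93151921023180e-01f;  // Beta0

    const float logistic = std::min(std::max(p / q + 0.5f, 0.0f), 1.0f);
    return x * logistic;  // SiLU(x) = x * sigmoid(x)
}

float ReferenceGeluErf(float x) {
    // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), matching kInvSqrt2 in gelu.cpp.
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678118654752440f));
}

The clamp keeps the rational approximation inside the range where the coefficients are valid, and it is also why the AVX512 kernels add an explicit NaN passthrough: _mm512_min_ps/_mm512_max_ps against a NaN input can return the finite operand, so without the extra blend a NaN input would silently produce a finite SiLU or GELU result.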