Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
101c0f5
Add fused Silu and Gelu kernels for AVX512
hariharans29 Mar 17, 2026
8b3c23a
Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp
hariharans29 Mar 17, 2026
26ed025
Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp
hariharans29 Mar 17, 2026
3bf1f01
Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp
hariharans29 Mar 17, 2026
e31395d
Slight adjustments in the code
hariharans29 Mar 17, 2026
0822fc9
Merge branch 'hari/fused_silu_avx512' of https://github.com/microsoft…
hariharans29 Mar 17, 2026
9a793a3
More build changes
hariharans29 Mar 17, 2026
2cd17b1
More changes
hariharans29 Mar 17, 2026
d99bfd8
Potential fix for pull request finding
hariharans29 Mar 17, 2026
275b69a
Potential fix for pull request finding
hariharans29 Mar 17, 2026
3a80418
Fix ARM build + Copilot suggestions
hariharans29 Mar 17, 2026
a3f7033
Update onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp
hariharans29 Mar 17, 2026
08157ab
Remove Minimax approx + address PR feedback
hariharans29 Mar 21, 2026
7d23425
Update onnxruntime/test/providers/cpu/activation/activation_op_test.cc
hariharans29 Mar 21, 2026
9f9bac7
Update onnxruntime/test/mlas/bench/bench_transcendental.cpp
hariharans29 Mar 21, 2026
dc63f2f
Copilot comments
hariharans29 Mar 21, 2026
b8ada76
Update onnxruntime/core/mlas/lib/silu.cpp
hariharans29 Mar 21, 2026
ad7a0c2
Update onnxruntime/core/mlas/lib/gelu.cpp
hariharans29 Mar 21, 2026
cbffffa
Adjust comment
hariharans29 Mar 21, 2026
50980aa
Copilot comments
hariharans29 Mar 21, 2026
da55583
Copilot comments
hariharans29 Mar 21, 2026
bb3cd55
Update onnxruntime/test/mlas/bench/bench_transcendental.cpp
hariharans29 Mar 21, 2026
f6a22fc
Adjust AVX512 path to match generic path for special values
hariharans29 Mar 21, 2026
41bfd24
Merge branch 'hari/fused_silu_avx512' of https://github.com/microsoft…
hariharans29 Mar 21, 2026
9a44721
a
hariharans29 Mar 21, 2026
1f4f3a1
Revert accidental change
hariharans29 Mar 21, 2026
0f7db94
Update onnxruntime/core/mlas/lib/silu.cpp
hariharans29 Mar 21, 2026
84cf795
Update onnxruntime/core/mlas/lib/gelu.cpp
hariharans29 Mar 21, 2026
7c84be7
Copilot comment
hariharans29 Mar 22, 2026
c9a549c
Copilot comment
hariharans29 Mar 22, 2026
05ca9c2
Rework Silu
hariharans29 Mar 22, 2026
5f844ba
Experiment
hariharans29 Mar 23, 2026
bf41c63
Experiment with Exact Gelu
hariharans29 Mar 23, 2026
bd93543
Merge remote-tracking branch 'origin/main' into hari/fused_silu_avx512
hariharans29 Mar 23, 2026
28bf32e
Copilot comments
hariharans29 Mar 23, 2026
9476522
PR feedback
hariharans29 Mar 23, 2026
cfa0bf2
Nit
hariharans29 Mar 23, 2026
d459cf7
Fix alignment
hariharans29 Mar 23, 2026
8652b3d
Update onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp
hariharans29 Mar 23, 2026
fcf1f1a
Alignment
hariharans29 Mar 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/eltwise.h
${MLAS_SRC_DIR}/eltwise.cpp
${MLAS_SRC_DIR}/erf.cpp
${MLAS_SRC_DIR}/silu.cpp
${MLAS_SRC_DIR}/gelu.cpp
${MLAS_SRC_DIR}/compute.cpp
${MLAS_SRC_DIR}/dequantize.cpp
${MLAS_SRC_DIR}/quantize.cpp
Expand Down Expand Up @@ -201,6 +203,14 @@ function(setup_mlas_source_for_windows)
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2")

set(mlas_platform_srcs_avx512
${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
)

set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512")

target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/dgemm.cpp
${mlas_platform_srcs_avx}
Expand All @@ -212,7 +222,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${mlas_platform_srcs_avx512}
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
Expand Down Expand Up @@ -764,6 +774,8 @@ endif()
${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S
${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S
${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S
${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f")
Expand Down
28 changes: 14 additions & 14 deletions onnxruntime/contrib_ops/cpu/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,22 @@
T* p_output = output_data + start;
int64_t count = std::min(length_per_task, elem_count - start);

if (alpha_ != 1.0f) {
// TODO: Consider vectorizing this scalar multiplication.
// It needs exposing a new API in MLAS to take in a scalar
// that will be used in the elementwise multiplication.
// Estimate the cost-benefit tradeoff before proceeding
// with that optimization.
for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * alpha_;
}

MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
} else {
// SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f
MlasComputeLogistic(p_input, p_output, onnxruntime::narrow<size_t>(count));
if (alpha_ == 1.0f) {
MlasComputeSilu(p_input, p_output, onnxruntime::narrow<size_t>(count));
return;
}

// TODO: Consider vectorizing this scalar multiplication.

Check warning on line 86 in onnxruntime/contrib_ops/cpu/activations.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/contrib_ops/cpu/activations.h:86: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// It needs exposing a new API in MLAS to take in a scalar
// that will be used in the elementwise multiplication.
// Estimate the cost-benefit tradeoff before proceeding
// with that optimization.
for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * alpha_;
}

MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));

MlasEltwiseMul<float>(p_input, p_output, p_output, onnxruntime::narrow<size_t>(count));
},
0);
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,30 @@ MlasComputeErf(
size_t N
);

//
// Note: The Input and Output buffers for MlasComputeGeluErf must not overlap.
// In-place operation (e.g., passing the same buffer for both parameters) is unsupported.
//
void
MLASCALL
MlasComputeGeluErf(
const float* Input,
float* Output,
size_t N
);

//
// Note: The Input and Output buffers for MlasComputeSilu must not overlap.
// In-place operation (e.g., passing the same buffer for both parameters) is unsupported.
//
void
MLASCALL
MlasComputeSilu(
const float* Input,
float* Output,
size_t N
);
Comment thread
hariharans29 marked this conversation as resolved.

template <typename T>
void
MLASCALL
Expand Down
65 changes: 65 additions & 0 deletions onnxruntime/core/mlas/lib/gelu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

gelu.cpp

Abstract:

This module implements routines to compute the exact Gelu function.

--*/

#include "mlasi.h"

namespace {

constexpr float kInvSqrt2 = 0.70710678118654752440f;

} // namespace


void
MLASCALL
MlasGeluErfKernel(
const float* Input,
float* Output,
size_t N
)
{
// This kernel is not buffer alias safe because it is implemented in
// multiple passes: first scale Input into Output, then apply erf in place,
// and finally combine that intermediate with the original Input values.
// Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements).
for (size_t i = 0; i < N; ++i) {
Output[i] = Input[i] * kInvSqrt2;
}
Comment thread
hariharans29 marked this conversation as resolved.

MlasComputeErf(Output, Output, N);

for (size_t i = 0; i < N; ++i) {
Output[i] = 0.5f * Input[i] * (Output[i] + 1.0f);
}
}

void
MLASCALL
MlasComputeGeluErf(
const float* Input,
float* Output,
size_t N
)
{
#if defined(MLAS_TARGET_AMD64)
// TODO: Add an intermediate fused AVX2/FMA3 GELU(erf) path on AMD64.
// Today the dispatch jumps from the generic multi-pass implementation to
// AVX512F, so non-AVX512 x64 machines fall back to the generic kernel.
GetMlasPlatform().GeluErfKernelRoutine(Input, Output, N);
#else
MlasGeluErfKernel(Input, Output, N);
#endif
}
219 changes: 219 additions & 0 deletions onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

gelu_avx512f.cpp

Abstract:

This module implements routines to compute exact Gelu with AVX512F
intrinsics.

--*/

#include <cstdint>

#include "mlasi.h"

namespace {

struct GeluAvx512Constants {
static constexpr int32_t SignBitMask = INT32_MIN;
static constexpr float InvSqrt2 = 0.70710678118654752440f;
static constexpr float Half = 0.5f;
static constexpr float One = 1.0f;

static constexpr float ErfUpperAbsRange = 3.925f;
static constexpr float ErfSplitBoundary = 0.921875f;
static constexpr float ErfSMALL_P0 = -5.99104969e-4f;
static constexpr float ErfSMALL_P1 = 4.99339588e-3f;
static constexpr float ErfSMALL_P2 = -2.67667342e-2f;
static constexpr float ErfSMALL_P3 = 1.12818025e-1f;
static constexpr float ErfSMALL_P4 = -3.76124859e-1f;
static constexpr float ErfSMALL_P5_Minus_One = 1.28379151e-1f;
static constexpr float ErfBIG_P0 = 1.72948930e-5f;
static constexpr float ErfBIG_P1 = -3.83208680e-4f;
static constexpr float ErfBIG_P2 = 3.88393435e-3f;
static constexpr float ErfBIG_P3 = -2.42545605e-2f;
static constexpr float ErfBIG_P4 = 1.06777847e-1f;
static constexpr float ErfBIG_P5 = 6.34846687e-1f;
static constexpr float ErfBIG_P6_Minus_One = 1.28717512e-1f;
static constexpr float ErfOne = 1.0f;
static constexpr float ExpLowerRange = -88.3762626647949f;
static constexpr float ExpLog2Reciprocal = 1.44269504088896341f;
static constexpr float ExpLog2Hi = -6.93145752e-1f;
static constexpr float ExpLog2Lo = -1.42860677e-6f;
static constexpr float ExpP0 = 1.38319808e-3f;
static constexpr float ExpP1 = 8.37550033e-3f;
static constexpr float ExpP2 = 4.16689515e-2f;
static constexpr float ExpP3 = 1.66664466e-1f;
static constexpr float ExpP4 = 4.99999851e-1f;
static constexpr float ExpP5 = 1.0f;
static constexpr float ExpP6 = 1.0f;
static constexpr float ExpC = 1.25829120e+7f;
};

struct GeluAvx512BroadcastConstants {
const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask));
const __m512 Zero = _mm512_setzero_ps();
const __m512 InvSqrt2 = _mm512_set1_ps(GeluAvx512Constants::InvSqrt2);
const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half);
const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One);
const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange);
const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary);
const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0);
const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1);
const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2);
const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3);
const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4);
const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One);
const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0);
const __m512 ErfBigP1 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1);
const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2);
const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3);
const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4);
const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5);
const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One);
const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne);
const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange);
const __m512 ExpLog2Reciprocal = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Reciprocal);
const __m512 ExpLog2Hi = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Hi);
const __m512 ExpLog2Lo = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Lo);
const __m512 ExpP0 = _mm512_set1_ps(GeluAvx512Constants::ExpP0);
const __m512 ExpP1 = _mm512_set1_ps(GeluAvx512Constants::ExpP1);
const __m512 ExpP2 = _mm512_set1_ps(GeluAvx512Constants::ExpP2);
const __m512 ExpP3 = _mm512_set1_ps(GeluAvx512Constants::ExpP3);
const __m512 ExpP4 = _mm512_set1_ps(GeluAvx512Constants::ExpP4);
const __m512 ExpP5 = _mm512_set1_ps(GeluAvx512Constants::ExpP5);
const __m512 ExpP6 = _mm512_set1_ps(GeluAvx512Constants::ExpP6);
const __m512 ExpC = _mm512_set1_ps(GeluAvx512Constants::ExpC);
};

MLAS_FORCEINLINE __m512
MlasGeluErfExpVectorAvx512(
__m512 Value,
const GeluAvx512BroadcastConstants& Constants
)
{
__m512 R = _mm512_fmadd_ps(Constants.ExpLog2Reciprocal, Value, Constants.ExpC);
R = _mm512_sub_ps(R, Constants.ExpC);

__m512 Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Hi, Value);
Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Lo, Fx);

__m512 Y = Constants.ExpP0;
Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP1);
Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP2);
Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP3);
Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP4);
Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP5);
Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP6);
Y = _mm512_scalef_ps(Y, R);

return Y;
}

MLAS_FORCEINLINE __m512
MlasGeluErfAvx512(
__m512 Value,
const GeluAvx512BroadcastConstants& Constants
)
{
const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(Constants.NegZero)));
__m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(Constants.NegZero), _mm512_castps_si512(Value)));
AbsValue = _mm512_min_ps(Constants.ErfUpperAbsRange, AbsValue);

const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue);

__m512 SmallResult = Constants.ErfSmallP0;
SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP1);
SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP2);
SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP3);
SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP4);
SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP5MinusOne);
SmallResult = _mm512_fmadd_ps(SmallResult, AbsValue, AbsValue);

const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, Constants.ErfSplitBoundary, _CMP_GT_OQ);
const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Constants.Zero, AbsValue);

__m512 BigResult = Constants.ErfBigP0;
BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP1);
BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP2);
BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP3);
BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP4);
BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP5);
BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP6MinusOne);
BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput);

BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(Constants.NegZero)));
BigResult = _mm512_max_ps(Constants.ExpLowerRange, BigResult);
BigResult = _mm512_sub_ps(Constants.ErfOne, MlasGeluErfExpVectorAvx512(BigResult, Constants));

__m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult);
Result = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(Result), _mm512_castps_si512(SignMask)));
return Result;
}

MLAS_FORCEINLINE __m512
MlasComputeGeluVectorExactAvx512(
__m512 X,
const GeluAvx512BroadcastConstants& Constants
)
{
const __m512 ErfInput = _mm512_mul_ps(X, Constants.InvSqrt2);
const __m512 ErfValue = MlasGeluErfAvx512(ErfInput, Constants);
__m512 Result = _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One));

// Preserve NaN payload/sign behavior explicitly because the erf
// approximation uses min/max style range limiting that is not guaranteed to
// preserve NaNs the same way as the existing MLAS GELU semantics.
const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q);
Result = _mm512_mask_mov_ps(Result, NaNMask, X);

return Result;
}

void
MlasGeluErfKernelAvx512FExactImpl(
const float* Input,
float* Output,
size_t N
)
{
const GeluAvx512BroadcastConstants Constants;
while (N >= 16) {
const __m512 X = _mm512_loadu_ps(Input);
const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants);

_mm512_storeu_ps(Output, Result);

Input += 16;
Output += 16;
N -= 16;
}

if (N > 0) {
const __mmask16 TailMask = __mmask16((1u << static_cast<unsigned>(N)) - 1u);
const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input);
const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants);

_mm512_mask_storeu_ps(Output, TailMask, Result);
}
}

} // namespace

void
MLASCALL
MlasGeluErfKernelAvx512F(
const float* Input,
float* Output,
size_t N
)
{
MlasGeluErfKernelAvx512FExactImpl(Input, Output, N);
}
Loading
Loading