diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 0156e46b86bc4..4f75a8b105ec2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -34,6 +34,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.h ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/silu.cpp + ${MLAS_SRC_DIR}/gelu.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp @@ -201,6 +203,14 @@ function(setup_mlas_source_for_windows) ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set(mlas_platform_srcs_avx512 + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + + set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512") + target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/dgemm.cpp ${mlas_platform_srcs_avx} @@ -212,7 +222,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${mlas_platform_srcs_avx512} ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp @@ -764,6 +774,8 @@ endif() ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index f00fad809968f..71e0e8561e110 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -78,22 +78,22 @@ class QuickGelu : public OpKernel { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - if (alpha_ != 1.0f) { - // TODO: Consider vectorizing this scalar multiplication. - // It needs exposing a new API in MLAS to take in a scalar - // that will be used in the elementwise multiplication. - // Estimate the cost-benefit tradeoff before proceeding - // with that optimization. - for (int64_t i = 0; i < count; i++) { - p_output[i] = p_input[i] * alpha_; - } - - MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); - } else { - // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f - MlasComputeLogistic(p_input, p_output, onnxruntime::narrow(count)); + if (alpha_ == 1.0f) { + MlasComputeSilu(p_input, p_output, onnxruntime::narrow(count)); + return; } + // TODO: Consider vectorizing this scalar multiplication. + // It needs exposing a new API in MLAS to take in a scalar + // that will be used in the elementwise multiplication. + // Estimate the cost-benefit tradeoff before proceeding + // with that optimization. + for (int64_t i = 0; i < count; i++) { + p_output[i] = p_input[i] * alpha_; + } + + MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); + MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow(count)); }, 0); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 56849995656f3..2b446c4b2601b 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1113,6 +1113,30 @@ MlasComputeErf( size_t N ); +// +// Note: The Input and Output buffers for MlasComputeGeluErf must not overlap. +// In-place operation (e.g., passing the same buffer for both parameters) is unsupported. +// +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N + ); + +// +// Note: The Input and Output buffers for MlasComputeSilu must not overlap. +// In-place operation (e.g., passing the same buffer for both parameters) is unsupported. +// +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ); + template void MLASCALL diff --git a/onnxruntime/core/mlas/lib/gelu.cpp b/onnxruntime/core/mlas/lib/gelu.cpp new file mode 100644 index 0000000000000..dc25611652c77 --- /dev/null +++ b/onnxruntime/core/mlas/lib/gelu.cpp @@ -0,0 +1,65 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu.cpp + +Abstract: + + This module implements routines to compute the exact Gelu function. + +--*/ + +#include "mlasi.h" + +namespace { + +constexpr float kInvSqrt2 = 0.70710678118654752440f; + +} // namespace + + +void +MLASCALL +MlasGeluErfKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe because it is implemented in + // multiple passes: first scale Input into Output, then apply erf in place, + // and finally combine that intermediate with the original Input values. + // Callers must guarantee that Input and Output do not overlap (see mlas.h for aliasing requirements). + for (size_t i = 0; i < N; ++i) { + Output[i] = Input[i] * kInvSqrt2; + } + + MlasComputeErf(Output, Output, N); + + for (size_t i = 0; i < N; ++i) { + Output[i] = 0.5f * Input[i] * (Output[i] + 1.0f); + } +} + +void +MLASCALL +MlasComputeGeluErf( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 GELU(erf) path on AMD64. + // Today the dispatch jumps from the generic multi-pass implementation to + // AVX512F, so non-AVX512 x64 machines fall back to the generic kernel. + GetMlasPlatform().GeluErfKernelRoutine(Input, Output, N); +#else + MlasGeluErfKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp new file mode 100644 index 0000000000000..4a9f3a100ed65 --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/gelu_avx512f.cpp @@ -0,0 +1,219 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + gelu_avx512f.cpp + +Abstract: + + This module implements routines to compute exact Gelu with AVX512F + intrinsics. + +--*/ + +#include + +#include "mlasi.h" + +namespace { + +struct GeluAvx512Constants { + static constexpr int32_t SignBitMask = INT32_MIN; + static constexpr float InvSqrt2 = 0.70710678118654752440f; + static constexpr float Half = 0.5f; + static constexpr float One = 1.0f; + + static constexpr float ErfUpperAbsRange = 3.925f; + static constexpr float ErfSplitBoundary = 0.921875f; + static constexpr float ErfSMALL_P0 = -5.99104969e-4f; + static constexpr float ErfSMALL_P1 = 4.99339588e-3f; + static constexpr float ErfSMALL_P2 = -2.67667342e-2f; + static constexpr float ErfSMALL_P3 = 1.12818025e-1f; + static constexpr float ErfSMALL_P4 = -3.76124859e-1f; + static constexpr float ErfSMALL_P5_Minus_One = 1.28379151e-1f; + static constexpr float ErfBIG_P0 = 1.72948930e-5f; + static constexpr float ErfBIG_P1 = -3.83208680e-4f; + static constexpr float ErfBIG_P2 = 3.88393435e-3f; + static constexpr float ErfBIG_P3 = -2.42545605e-2f; + static constexpr float ErfBIG_P4 = 1.06777847e-1f; + static constexpr float ErfBIG_P5 = 6.34846687e-1f; + static constexpr float ErfBIG_P6_Minus_One = 1.28717512e-1f; + static constexpr float ErfOne = 1.0f; + static constexpr float ExpLowerRange = -88.3762626647949f; + static constexpr float ExpLog2Reciprocal = 1.44269504088896341f; + static constexpr float ExpLog2Hi = -6.93145752e-1f; + static constexpr float ExpLog2Lo = -1.42860677e-6f; + static constexpr float ExpP0 = 1.38319808e-3f; + static constexpr float ExpP1 = 8.37550033e-3f; + static constexpr float ExpP2 = 4.16689515e-2f; + static constexpr float ExpP3 = 1.66664466e-1f; + static constexpr float ExpP4 = 4.99999851e-1f; + static constexpr float ExpP5 = 1.0f; + static constexpr float ExpP6 = 1.0f; + static constexpr float ExpC = 1.25829120e+7f; +}; + +struct GeluAvx512BroadcastConstants { + const __m512 NegZero = _mm512_castsi512_ps(_mm512_set1_epi32(GeluAvx512Constants::SignBitMask)); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 InvSqrt2 = _mm512_set1_ps(GeluAvx512Constants::InvSqrt2); + const __m512 Half = _mm512_set1_ps(GeluAvx512Constants::Half); + const __m512 One = _mm512_set1_ps(GeluAvx512Constants::One); + const __m512 ErfUpperAbsRange = _mm512_set1_ps(GeluAvx512Constants::ErfUpperAbsRange); + const __m512 ErfSplitBoundary = _mm512_set1_ps(GeluAvx512Constants::ErfSplitBoundary); + const __m512 ErfSmallP0 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P0); + const __m512 ErfSmallP1 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P1); + const __m512 ErfSmallP2 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P2); + const __m512 ErfSmallP3 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P3); + const __m512 ErfSmallP4 = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P4); + const __m512 ErfSmallP5MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfSMALL_P5_Minus_One); + const __m512 ErfBigP0 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P0); + const __m512 ErfBigP1 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P1); + const __m512 ErfBigP2 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P2); + const __m512 ErfBigP3 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P3); + const __m512 ErfBigP4 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P4); + const __m512 ErfBigP5 = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P5); + const __m512 ErfBigP6MinusOne = _mm512_set1_ps(GeluAvx512Constants::ErfBIG_P6_Minus_One); + const __m512 ErfOne = _mm512_set1_ps(GeluAvx512Constants::ErfOne); + const __m512 ExpLowerRange = _mm512_set1_ps(GeluAvx512Constants::ExpLowerRange); + const __m512 ExpLog2Reciprocal = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Reciprocal); + const __m512 ExpLog2Hi = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Hi); + const __m512 ExpLog2Lo = _mm512_set1_ps(GeluAvx512Constants::ExpLog2Lo); + const __m512 ExpP0 = _mm512_set1_ps(GeluAvx512Constants::ExpP0); + const __m512 ExpP1 = _mm512_set1_ps(GeluAvx512Constants::ExpP1); + const __m512 ExpP2 = _mm512_set1_ps(GeluAvx512Constants::ExpP2); + const __m512 ExpP3 = _mm512_set1_ps(GeluAvx512Constants::ExpP3); + const __m512 ExpP4 = _mm512_set1_ps(GeluAvx512Constants::ExpP4); + const __m512 ExpP5 = _mm512_set1_ps(GeluAvx512Constants::ExpP5); + const __m512 ExpP6 = _mm512_set1_ps(GeluAvx512Constants::ExpP6); + const __m512 ExpC = _mm512_set1_ps(GeluAvx512Constants::ExpC); +}; + +MLAS_FORCEINLINE __m512 +MlasGeluErfExpVectorAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + __m512 R = _mm512_fmadd_ps(Constants.ExpLog2Reciprocal, Value, Constants.ExpC); + R = _mm512_sub_ps(R, Constants.ExpC); + + __m512 Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Hi, Value); + Fx = _mm512_fmadd_ps(R, Constants.ExpLog2Lo, Fx); + + __m512 Y = Constants.ExpP0; + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP1); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP2); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP3); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP4); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP5); + Y = _mm512_fmadd_ps(Y, Fx, Constants.ExpP6); + Y = _mm512_scalef_ps(Y, R); + + return Y; +} + +MLAS_FORCEINLINE __m512 +MlasGeluErfAvx512( + __m512 Value, + const GeluAvx512BroadcastConstants& Constants + ) +{ + const __m512 SignMask = _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(Value), _mm512_castps_si512(Constants.NegZero))); + __m512 AbsValue = _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(Constants.NegZero), _mm512_castps_si512(Value))); + AbsValue = _mm512_min_ps(Constants.ErfUpperAbsRange, AbsValue); + + const __m512 SquareValue = _mm512_mul_ps(AbsValue, AbsValue); + + __m512 SmallResult = Constants.ErfSmallP0; + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP1); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP2); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP3); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP4); + SmallResult = _mm512_fmadd_ps(SmallResult, SquareValue, Constants.ErfSmallP5MinusOne); + SmallResult = _mm512_fmadd_ps(SmallResult, AbsValue, AbsValue); + + const __mmask16 SplitMask = _mm512_cmp_ps_mask(AbsValue, Constants.ErfSplitBoundary, _CMP_GT_OQ); + const __m512 BigInput = _mm512_mask_blend_ps(SplitMask, Constants.Zero, AbsValue); + + __m512 BigResult = Constants.ErfBigP0; + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP1); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP2); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP3); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP4); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP5); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, Constants.ErfBigP6MinusOne); + BigResult = _mm512_fmadd_ps(BigResult, BigInput, BigInput); + + BigResult = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(BigResult), _mm512_castps_si512(Constants.NegZero))); + BigResult = _mm512_max_ps(Constants.ExpLowerRange, BigResult); + BigResult = _mm512_sub_ps(Constants.ErfOne, MlasGeluErfExpVectorAvx512(BigResult, Constants)); + + __m512 Result = _mm512_mask_blend_ps(SplitMask, SmallResult, BigResult); + Result = _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(Result), _mm512_castps_si512(SignMask))); + return Result; +} + +MLAS_FORCEINLINE __m512 +MlasComputeGeluVectorExactAvx512( + __m512 X, + const GeluAvx512BroadcastConstants& Constants + ) +{ + const __m512 ErfInput = _mm512_mul_ps(X, Constants.InvSqrt2); + const __m512 ErfValue = MlasGeluErfAvx512(ErfInput, Constants); + __m512 Result = _mm512_mul_ps(_mm512_mul_ps(Constants.Half, X), _mm512_add_ps(ErfValue, Constants.One)); + + // Preserve NaN payload/sign behavior explicitly because the erf + // approximation uses min/max style range limiting that is not guaranteed to + // preserve NaNs the same way as the existing MLAS GELU semantics. + const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + return Result; +} + +void +MlasGeluErfKernelAvx512FExactImpl( + const float* Input, + float* Output, + size_t N + ) +{ + const GeluAvx512BroadcastConstants Constants; + while (N >= 16) { + const __m512 X = _mm512_loadu_ps(Input); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants); + + _mm512_storeu_ps(Output, Result); + + Input += 16; + Output += 16; + N -= 16; + } + + if (N > 0) { + const __mmask16 TailMask = __mmask16((1u << static_cast(N)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input); + const __m512 Result = MlasComputeGeluVectorExactAvx512(X, Constants); + + _mm512_mask_storeu_ps(Output, TailMask, Result); + } +} + +} // namespace + +void +MLASCALL +MlasGeluErfKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + MlasGeluErfKernelAvx512FExactImpl(Input, Output, N); +} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp new file mode 100644 index 0000000000000..7e8424d94827a --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx512/silu_avx512f.cpp @@ -0,0 +1,140 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + silu_avx512f.cpp + +Abstract: + + This module implements routines to compute the SiLU function with AVX512F + intrinsics. + +--*/ + +#include "mlasi.h" + +namespace { + +struct SiluAvx512Constants { + static constexpr float LogisticLowerRange = -18.0f; + static constexpr float LogisticUpperRange = 18.0f; + static constexpr float Alpha9 = 4.37031012579801e-11f; + static constexpr float Alpha7 = 1.15627324459942e-07f; + static constexpr float Alpha5 = 6.08574864600143e-05f; + static constexpr float Alpha3 = 8.51377133304701e-03f; + static constexpr float Alpha1 = 2.48287947061529e-01f; + static constexpr float Beta10 = 6.10247389755681e-13f; + static constexpr float Beta8 = 5.76102136993427e-09f; + static constexpr float Beta6 = 6.29106785017040e-06f; + static constexpr float Beta4 = 1.70198817374094e-03f; + static constexpr float Beta2 = 1.16817656904453e-01f; + static constexpr float Beta0 = 9.93151921023180e-01f; + static constexpr float OneHalf = 0.5f; +}; + +struct SiluAvx512BroadcastConstants { + const __m512 LogisticLowerRange = _mm512_set1_ps(SiluAvx512Constants::LogisticLowerRange); + const __m512 LogisticUpperRange = _mm512_set1_ps(SiluAvx512Constants::LogisticUpperRange); + const __m512 Alpha9 = _mm512_set1_ps(SiluAvx512Constants::Alpha9); + const __m512 Alpha7 = _mm512_set1_ps(SiluAvx512Constants::Alpha7); + const __m512 Alpha5 = _mm512_set1_ps(SiluAvx512Constants::Alpha5); + const __m512 Alpha3 = _mm512_set1_ps(SiluAvx512Constants::Alpha3); + const __m512 Alpha1 = _mm512_set1_ps(SiluAvx512Constants::Alpha1); + const __m512 Beta10 = _mm512_set1_ps(SiluAvx512Constants::Beta10); + const __m512 Beta8 = _mm512_set1_ps(SiluAvx512Constants::Beta8); + const __m512 Beta6 = _mm512_set1_ps(SiluAvx512Constants::Beta6); + const __m512 Beta4 = _mm512_set1_ps(SiluAvx512Constants::Beta4); + const __m512 Beta2 = _mm512_set1_ps(SiluAvx512Constants::Beta2); + const __m512 Beta0 = _mm512_set1_ps(SiluAvx512Constants::Beta0); + const __m512 OneHalf = _mm512_set1_ps(SiluAvx512Constants::OneHalf); + const __m512 Zero = _mm512_setzero_ps(); + const __m512 One = _mm512_set1_ps(1.0f); +}; + +MLAS_FORCEINLINE __m512 +MlasLogisticApproxAvx512( + __m512 Value, + const SiluAvx512BroadcastConstants& Constants + ) +{ + // Mirror MlasComputeLogistic by evaluating the same clamped rational + // approximation in-register and then multiplying by x for SiLU. + const __m512 ClampedValue = _mm512_max_ps(_mm512_min_ps(Value, Constants.LogisticUpperRange), Constants.LogisticLowerRange); + const __m512 ValueSquared = _mm512_mul_ps(ClampedValue, ClampedValue); + + __m512 P = _mm512_fmadd_ps(ValueSquared, Constants.Alpha9, Constants.Alpha7); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha5); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha3); + P = _mm512_fmadd_ps(P, ValueSquared, Constants.Alpha1); + P = _mm512_mul_ps(P, ClampedValue); + + __m512 Q = _mm512_fmadd_ps(ValueSquared, Constants.Beta10, Constants.Beta8); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta6); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta4); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta2); + Q = _mm512_fmadd_ps(Q, ValueSquared, Constants.Beta0); + + __m512 Logistic = _mm512_add_ps(_mm512_div_ps(P, Q), Constants.OneHalf); + Logistic = _mm512_min_ps(_mm512_max_ps(Logistic, Constants.Zero), Constants.One); + + return Logistic; +} + +MLAS_FORCEINLINE __m512 +MlasComputeSiluVectorAvx512( + __m512 X, + const SiluAvx512BroadcastConstants& Constants + ) +{ + __m512 Result = _mm512_mul_ps(X, MlasLogisticApproxAvx512(X, Constants)); + + // Preserve NaN payload/sign behavior explicitly because the clamped + // logistic approximation uses min/max operations that do not reliably + // propagate NaNs the same way as the existing MLAS SiLU semantics. + const __mmask16 NaNMask = _mm512_cmp_ps_mask(X, X, _CMP_UNORD_Q); + Result = _mm512_mask_mov_ps(Result, NaNMask, X); + + return Result; +} + +} // namespace + +void +MLASCALL +MlasSiluKernelAvx512F( + const float* Input, + float* Output, + size_t N + ) +{ + const SiluAvx512BroadcastConstants Constants; + size_t Offset = 0; + + while (Offset + 32 <= N) { + const __m512 X0 = _mm512_loadu_ps(Input + Offset); + const __m512 X1 = _mm512_loadu_ps(Input + Offset + 16); + const __m512 Result0 = MlasComputeSiluVectorAvx512(X0, Constants); + const __m512 Result1 = MlasComputeSiluVectorAvx512(X1, Constants); + _mm512_storeu_ps(Output + Offset, Result0); + _mm512_storeu_ps(Output + Offset + 16, Result1); + Offset += 32; + } + + while (Offset + 16 <= N) { + const __m512 X = _mm512_loadu_ps(Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); + _mm512_storeu_ps(Output + Offset, Result); + Offset += 16; + } + + if (Offset < N) { + const __mmask16 TailMask = static_cast<__mmask16>((1u << (N - Offset)) - 1u); + const __m512 X = _mm512_maskz_loadu_ps(TailMask, Input + Offset); + const __m512 Result = MlasComputeSiluVectorAvx512(X, Constants); + _mm512_mask_storeu_ps(Output + Offset, TailMask, Result); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 954849fe90049..0dab8e41f25cd 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1096,6 +1096,8 @@ extern "C" { #endif MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernel; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasTanhKernel; @@ -1126,6 +1128,8 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernelAvx512F; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelAvx512F; #endif MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; @@ -1477,6 +1481,8 @@ struct MLAS_PLATFORM { MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; #endif #if defined(MLAS_TARGET_AMD64) + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ac3761d63bd20..eccde79848e61 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -283,7 +283,9 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; + this->GeluErfKernelRoutine = MlasGeluErfKernel; this->LogisticKernelRoutine = MlasLogisticKernel; + this->SiluKernelRoutine = MlasSiluKernel; this->TanhKernelRoutine = MlasTanhKernel; this->ErfKernelRoutine = MlasErfKernel; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; @@ -459,7 +461,8 @@ Return Value: // if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { - + this->GeluErfKernelRoutine = MlasGeluErfKernelAvx512F; + this->SiluKernelRoutine = MlasSiluKernelAvx512F; this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F; diff --git a/onnxruntime/core/mlas/lib/silu.cpp b/onnxruntime/core/mlas/lib/silu.cpp new file mode 100644 index 0000000000000..96686e4bdf1da --- /dev/null +++ b/onnxruntime/core/mlas/lib/silu.cpp @@ -0,0 +1,51 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + silu.cpp + +Abstract: + + This module implements routines to compute the SiLU function. + +--*/ + +#include "mlasi.h" + +void +MLASCALL +MlasSiluKernel( + const float* Input, + float* Output, + size_t N + ) +{ + // This kernel is not buffer alias safe because it is implemented in two + // passes: first compute logistic(Input) into Output, then multiply that + // intermediate by the original Input values. Callers must guarantee that + // Input and Output do not overlap (see mlas.h for aliasing requirements). + MlasComputeLogistic(Input, Output, N); + MlasEltwiseMul(Input, Output, Output, N); +} + +void +MLASCALL +MlasComputeSilu( + const float* Input, + float* Output, + size_t N + ) +{ +#if defined(MLAS_TARGET_AMD64) + // TODO: Add an intermediate fused AVX2/FMA3 SiLU path on AMD64. Today the + // dispatch jumps from the generic two-pass implementation to AVX512F, so + // non-AVX512 x64 machines fall back to the generic kernel. + GetMlasPlatform().SiluKernelRoutine(Input, Output, N); +#else + MlasSiluKernel(Input, Output, N); +#endif +} diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index d55973eda180f..e34af83d1f29e 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -88,16 +88,9 @@ Status Gelu::Compute(OpKernelContext* context) const { T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } + // MlasComputeGeluErf requires distinct input/output buffers. This + // call uses disjoint slices from the input and output tensors. + MlasComputeGeluErf(p_input, p_output, narrow(count)); }, 0); return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h index 13238028d878a..14a070609a69b 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.h +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once + namespace onnxruntime { template diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp new file mode 100644 index 0000000000000..f7e461c29843a --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include + +#include "mlas.h" +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +namespace { + +// Compare fused MLAS unary activation paths against unfused baselines for +// SiLU and exact GELU(erf). + +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kInvSqrt2 = 0.7071067811865475244f; +constexpr int64_t kFusedBytesPerElement = 2; +constexpr int64_t kSiluUnfusedBytesPerElement = 5; +constexpr int64_t kGeluUnfusedBytesPerElement = 7; + +struct DispatchedUnaryPathInfo { + int64_t bytes_per_element; + const char* label; +}; + +DispatchedUnaryPathInfo GetSiluDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic path, + // which materializes the logistic result before the final multiply. + return {kSiluUnfusedBytesPerElement, "generic_fallback"}; +} + +DispatchedUnaryPathInfo GetGeluErfDispatchPathInfo() { +#if defined(MLAS_TARGET_AMD64) + if (GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F) { + return {kFusedBytesPerElement, "avx512_fused"}; + } +#endif + + // The current non-AVX512 dispatch target falls back to the generic exact + // GELU(erf) implementation, which uses separate scale, erf, and final passes. + return {kGeluUnfusedBytesPerElement, "generic_fallback"}; +} + +std::vector MakeInput(size_t n, float min_value, float max_value) { + auto data = RandomVectorUniform(n, min_value, max_value); + + if (!data.empty()) { + data[0] = 0.0f; + } + if (data.size() > 1) { + data[1] = -0.0f; + } + if (data.size() > 2) { + data[2] = -1.0f; + } + if (data.size() > 3) { + data[3] = 1.0f; + } + + return data; +} + +template +void RunDispatchedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + DispatchedUnaryPathInfo path_info) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + state.SetLabel(path_info.label); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + const int64_t bytes_per_iteration = static_cast(n) * static_cast(sizeof(float)) * path_info.bytes_per_element; + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * bytes_per_iteration); +} + +template +void RunUnfusedUnaryBenchmark(benchmark::State& state, + KernelFn&& kernel, + float min_value, + float max_value, + int64_t bytes_per_element) { + const auto n = static_cast(state.range(0)); + auto input = MakeInput(n, min_value, max_value); + std::vector output(n); + + kernel(input.data(), output.data(), n); + + for (auto _ : state) { + kernel(input.data(), output.data(), n); + benchmark::DoNotOptimize(output.data()); + benchmark::ClobberMemory(); + } + + const int64_t bytes_per_iteration = static_cast(n) * static_cast(sizeof(float)) * bytes_per_element; + state.SetItemsProcessed(static_cast(state.iterations()) * static_cast(n)); + state.SetBytesProcessed(static_cast(state.iterations()) * bytes_per_iteration); +} + +static void UnaryKernelArgs(benchmark::internal::Benchmark* b) { + for (int n : {1, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 511, 512, 1024, 4096, 16384, 65536, 262144}) { + b->Arg(n); + } +} + +void BM_SiluDispatch(benchmark::State& state) { + // Fused MLAS SiLU entry point. On supported platforms this may dispatch to a + // specialized implementation that combines the activation into a single + // kernel instead of exposing intermediate results. + RunDispatchedUnaryBenchmark(state, MlasComputeSilu, kSiluMinValue, kSiluMaxValue, GetSiluDispatchPathInfo()); +} + +void BM_SiluUnfusedDispatch(benchmark::State& state) { + // Unfused SiLU baseline: compute logistic(x) first and then multiply by x in + // a separate elementwise pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeLogistic(input, output, n); + MlasEltwiseMul(input, output, output, n); + }, + kSiluMinValue, + kSiluMaxValue, + kSiluUnfusedBytesPerElement); +} + +void BM_GeluErfDispatchExact(benchmark::State& state) { + // Fused MLAS GELU(erf) entry point using the exact erf-based formulation. + // On AMD64 this goes through the platform dispatch layer and may select an + // architecture-specific implementation. + RunDispatchedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + MlasComputeGeluErf(input, output, n); + }, + kGeluMinValue, + kGeluMaxValue, + GetGeluErfDispatchPathInfo()); +} + +void BM_GeluErfUnfusedExact(benchmark::State& state) { + // Unfused exact GELU(erf) baseline: scale by 1/sqrt(2), run erf, then apply the + // final 0.5 * x * (erf(x / sqrt(2)) + 1) transform in a separate pass. + RunUnfusedUnaryBenchmark( + state, + [](const float* input, float* output, size_t n) { + for (size_t i = 0; i < n; ++i) { + output[i] = input[i] * kInvSqrt2; + } + + MlasComputeErf(output, output, n); + + for (size_t i = 0; i < n; ++i) { + output[i] = 0.5f * input[i] * (output[i] + 1.0f); + } + }, + kGeluMinValue, + kGeluMaxValue, + kGeluUnfusedBytesPerElement); +} + +} // namespace + +BENCHMARK(BM_SiluDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_SiluUnfusedDispatch)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfDispatchExact)->Apply(UnaryKernelArgs)->UseRealTime(); +BENCHMARK(BM_GeluErfUnfusedExact)->Apply(UnaryKernelArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp new file mode 100644 index 0000000000000..e87768ce3e660 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_transcendental_avx512.cpp @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" + +#include +#include + +#if defined(MLAS_TARGET_AMD64) + +namespace { + +constexpr float kGeluMinValue = -10.0f; +constexpr float kGeluMaxValue = 10.0f; +constexpr float kSiluMinValue = -20.0f; +constexpr float kSiluMaxValue = 20.0f; + +constexpr float kGeluAbsoluteTolerance = 2e-6f; +constexpr float kGeluRelativeTolerance = 2e-5f; +constexpr float kSiluAbsoluteTolerance = 3e-5f; +constexpr float kSiluRelativeTolerance = 5e-5f; + +constexpr std::array kShortTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255}; + +constexpr std::array kLongTestSizes = { + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, 32, 33, 63, + 64, 65, 127, 128, 129, 255, 511, 512, 513, 1023, 1024, 1025, 4095}; + +bool IsGeluErfAvx512Dispatched() { + return GetMlasPlatform().GeluErfKernelRoutine == MlasGeluErfKernelAvx512F; +} + +bool IsSiluAvx512Dispatched() { + return GetMlasPlatform().SiluKernelRoutine == MlasSiluKernelAvx512F; +} + +bool UnaryOutputsMatch(float actual, float expected, float absolute_tolerance, float relative_tolerance, + bool check_signed_zero) { + if (std::isnan(expected)) { + return std::isnan(actual); + } + + if (std::isinf(expected)) { + return std::isinf(actual) && (std::signbit(actual) == std::signbit(expected)); + } + + if (check_signed_zero && actual == 0.0f && expected == 0.0f) { + return std::signbit(actual) == std::signbit(expected); + } + + const float diff = std::fabs(actual - expected); + if (diff <= absolute_tolerance) { + return true; + } + + const float scale = std::max(std::fabs(actual), std::fabs(expected)); + return scale > 0.0f && diff <= scale * relative_tolerance; +} + +const std::vector& GetGeluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 0.0f, + -0.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + }; + + return values; +} + +const std::vector& GetSiluSpecialValues() { + static const std::vector values = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::numeric_limits::max(), + -std::numeric_limits::max(), + 1.0e9f, + -1.0e9f, + 0.0f, + -0.0f, + -20.0f, + -10.0f, + -6.0f, + -3.0f, + -1.0f, + -0.5f, + 0.5f, + 1.0f, + 3.0f, + 6.0f, + 10.0f, + 20.0f, + }; + + return values; +} + +void FillInput(float* input, size_t n, float minimum_value, float maximum_value, + const std::vector& special_values, uint32_t seed) { + std::mt19937 generator(seed); + std::uniform_real_distribution distribution(minimum_value, maximum_value); + + for (size_t i = 0; i < n; ++i) { + input[i] = distribution(generator); + } + + const size_t special_count = std::min(n, special_values.size()); + for (size_t i = 0; i < special_count; ++i) { + input[i] = special_values[i]; + } +} + +class MlasComputeGeluErfAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer generic_output_buffer_; + MatrixGuardBuffer public_output_buffer_; + MatrixGuardBuffer avx512_output_buffer_; + + void ExecuteCommon(const std::vector& sizes, size_t iterations) { + if (!IsGeluErfAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F GELU(erf) dispatch is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kGeluMinValue, kGeluMaxValue, GetGeluSpecialValues(), + static_cast(size * 131u + iteration * 977u + 17u)); + + MlasGeluErfKernel(input, generic_output, size); + MlasComputeGeluErf(input, public_output, size); + MlasGeluErfKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public GELU(erf) mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "GELU(erf) mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kGeluAbsoluteTolerance, kGeluRelativeTolerance, true)) + << "Public/API GELU(erf) dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Gelu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +class MlasComputeSiluAvx512Test : public MlasTestBase { + private: + MatrixGuardBuffer input_buffer_; + MatrixGuardBuffer generic_output_buffer_; + MatrixGuardBuffer public_output_buffer_; + MatrixGuardBuffer avx512_output_buffer_; + + void ExecuteCommon(const std::vector& sizes, size_t iterations) { + if (!IsSiluAvx512Dispatched()) { + GTEST_SKIP() << "AVX512F SiLU dispatch is not available on this machine."; + } + + for (size_t size : sizes) { + for (size_t iteration = 0; iteration < iterations; ++iteration) { + float* input = input_buffer_.GetBuffer(size); + float* generic_output = generic_output_buffer_.GetBuffer(size); + float* public_output = public_output_buffer_.GetBuffer(size); + float* avx512_output = avx512_output_buffer_.GetBuffer(size); + + FillInput(input, size, kSiluMinValue, kSiluMaxValue, GetSiluSpecialValues(), + static_cast(size * 149u + iteration * 991u + 31u)); + + MlasSiluKernel(input, generic_output, size); + MlasComputeSilu(input, public_output, size); + MlasSiluKernelAvx512F(input, avx512_output, size); + + for (size_t i = 0; i < size; ++i) { + ASSERT_TRUE(UnaryOutputsMatch(public_output[i], generic_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", public=" << public_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(public_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], generic_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Silu mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", generic=" << generic_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - generic_output[i]); + + ASSERT_TRUE(UnaryOutputsMatch(avx512_output[i], public_output[i], + kSiluAbsoluteTolerance, kSiluRelativeTolerance, true)) + << "Public/API Silu dispatch mismatch at index " << i << " of " << size + << ", input=" << input[i] + << ", avx512=" << avx512_output[i] + << ", public=" << public_output[i] + << ", abs_diff=" << std::fabs(avx512_output[i] - public_output[i]); + } + } + } + } + + public: + static const char* GetTestSuiteName() { + return "TranscendentalAvx512Silu"; + } + + void ExecuteShort() override { + ExecuteCommon(std::vector(kShortTestSizes.begin(), kShortTestSizes.end()), 3); + } + + void ExecuteLong() override { + ExecuteCommon(std::vector(kLongTestSizes.begin(), kLongTestSizes.end()), 8); + } +}; + +} // namespace + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } else { + count += MlasLongExecuteTests::RegisterLongExecute(); + count += MlasLongExecuteTests::RegisterLongExecute(); + } + return count; +}); + +#else + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool) { + return size_t{0}; +}); + +#endif