diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h
index 7e64235d3fc3d..f00fad809968f 100644
--- a/onnxruntime/contrib_ops/cpu/activations.h
+++ b/onnxruntime/contrib_ops/cpu/activations.h
@@ -77,17 +77,27 @@ class QuickGelu : public OpKernel {
         const T* p_input = input_data + start;
         T* p_output = output_data + start;
         int64_t count = std::min(length_per_task, elem_count - start);
-        for (int64_t i = 0; i < count; i++) {
-          p_output[i] = p_input[i] * alpha_;
-        }
-        MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
+        if (alpha_ != 1.0f) {
+          // TODO: Consider vectorizing this scalar multiplication.
+          // It needs exposing a new API in MLAS to take in a scalar
+          // that will be used in the elementwise multiplication.
+          // Estimate the cost-benefit tradeoff before proceeding
+          // with that optimization.
+          for (int64_t i = 0; i < count; i++) {
+            p_output[i] = p_input[i] * alpha_;
+          }
 
-        for (int64_t i = 0; i < count; i++) {
-          p_output[i] = p_input[i] * p_output[i];
+          MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
+        } else {
+          // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f
+          MlasComputeLogistic(p_input, p_output, onnxruntime::narrow<size_t>(count));
         }
+
+        MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow<size_t>(count));
       },
       0);
+
   return Status::OK();
 }
 
  private:
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 248c6d74e6cbd..4923e8331e64d 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1126,6 +1126,16 @@ MlasEltwiseAdd(
     size_t N
     );
 
+template <typename T>
+void
+MLASCALL
+MlasEltwiseMul(
+    const T* left,
+    const T* right,
+    T* output,
+    size_t N
+    );
+
 template <typename T>
 void
 MLASCALL
diff --git a/onnxruntime/core/mlas/lib/eltwise.cpp b/onnxruntime/core/mlas/lib/eltwise.cpp
index f63d71b40bfbb..82457deb811a2 100644
--- a/onnxruntime/core/mlas/lib/eltwise.cpp
+++ b/onnxruntime/core/mlas/lib/eltwise.cpp
@@ -53,6 +53,38 @@ MlasEltwiseAdd(
     }
 }
 
+template <>
+void
+MLASCALL
+MlasEltwiseMul(
+    const float* left,
+    const float* right,
+    float* output,
+    size_t N
+) {
+    while (N > 0) {
+        if (N >= 4) {
+            MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left);
+            MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right);
+
+            MLAS_FLOAT32X4 ResultVec = MlasMultiplyFloat32x4(LeftVec, RightVec);
+
+            MlasStoreFloat32x4(output, ResultVec);
+
+            left += 4;
+            right += 4;
+            output += 4;
+            N -= 4;
+        } else {
+            *output = *left * *right;
+
+            left += 1;
+            right += 1;
+            output += 1;
+            N -= 1;
+        }
+    }
+}
 
 template <>
 void
diff --git a/onnxruntime/test/mlas/unittest/test_eltwise.cpp b/onnxruntime/test/mlas/unittest/test_eltwise.cpp
index c4d4b9c0eb317..136d3a9a756b4 100644
--- a/onnxruntime/test/mlas/unittest/test_eltwise.cpp
+++ b/onnxruntime/test/mlas/unittest/test_eltwise.cpp
@@ -97,10 +97,62 @@ class MlasEltwiseAddTest : public MlasTestBase {
   }
 };
 
+class MlasEltwiseMulTest : public MlasTestBase {
+ private:
+  MatrixGuardBuffer<float> BufferInputLeft;
+  MatrixGuardBuffer<float> BufferInputRight;
+  MatrixGuardBuffer<float> BufferOutput;
+  MatrixGuardBuffer<float> BufferOutputReference;
+
+  void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional<float>& ScalarValue = std::nullopt) {
+    float* InputLeft = BufferInputLeft.GetBuffer(N);
+    float* InputRight = BufferInputRight.GetBuffer(N);
+    float* Output = BufferOutput.GetBuffer(N);
+    float* OutputReference = BufferOutputReference.GetBuffer(N);
+
+    std::default_random_engine generator(static_cast<unsigned>(N));
+    std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
+
+    for (size_t n = 0; n < N; n++) {
+      InputLeft[n] = distribution(generator);
+      InputRight[n] = ScalarValue.value_or(distribution(generator));
+    }
+
+    for (size_t n = 0; n < N; n++) {
+      OutputReference[n] = InputLeft[n] * InputRight[n];
+    }
+
+    MlasEltwiseMul(InputLeft, InputRight, Output, N);
+
+    constexpr float AbsoluteTolerance = 1e-6f;
+    constexpr float RelativeTolerance = 1e-6f;
+
+    for (size_t n = 0; n < N; n++) {
+      float diff = std::fabs(Output[n] - OutputReference[n]);
+      ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance)
+          << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n];
+    }
+  }
+
+ public:
+  static const char* GetTestSuiteName() {
+    static const std::string suite_name("Eltwise_Mul");
+    return suite_name.c_str();
+  }
+
+  void ExecuteShort(void) override {
+    for (size_t n = 1; n < 128; n++) {
+      Test(n, -10.f, 10.f);
+      Test(n, -10.f, 10.f, -5000.f);
+    }
+  }
+};
+
 static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
   size_t count = 0;
   if (is_short_execute) {
     count += MlasDirectShortExecuteTests<MlasEltwiseAddTest>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasEltwiseMulTest>::RegisterShortExecute();
   }
   return count;
 });