diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index f9c8a423be241..695afc661dd99 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -71,8 +71,8 @@ jobs: run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1436672" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672" + echo "Binary size threshold in bytes: 1722565" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1722565" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 3530ab03c822a..b922a78a6929d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -109,6 +109,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -431,6 +433,8 @@ else() ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_kernel_neon.cpp ) if (onnxruntime_USE_KLEIDIAI) setup_kleidiai() diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4ad13f072ae72..4cad44a56ba96 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -949,6 +949,15 @@ extern "C" { #if defined(__aarch64__) && defined(__linux__) MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif +#if defined(MLAS_TARGET_ARM64) + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon; #endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; @@ -1335,6 +1344,12 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + uint32_t NchwcBlockSize; #endif const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; @@ -1395,6 +1410,7 @@ struct MLAS_PLATFORM { int32_t MaximumThreadCount; #elif defined(MLAS_TARGET_ARM64) static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4; + static constexpr size_t MLAS_NEON_NCHWC_BLOCK_SIZE = 16; #else static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; #endif diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index c4b8d5e78a491..923e513ccb07a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -558,6 +558,15 @@ Return Value: this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; 
this->EltwiseDispatch = &MlasEltwiseDispatchNeon; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; + this->NchwcBlockSize = MLAS_NEON_NCHWC_BLOCK_SIZE; + // // Check if the processor supports ASIMD dot product instructions. // diff --git a/onnxruntime/core/mlas/lib/sconv.h b/onnxruntime/core/mlas/lib/sconv.h new file mode 100644 index 0000000000000..94e657638975a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sconv.h @@ -0,0 +1,25 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv.h + +Abstract: + + This module defines convolution kernel flags for configuring convolution + operations including output accumulation, bias addition, and activations. + +--*/ + +// +// Define the convolution kernel flags. +// + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp new file mode 100644 index 0000000000000..3ecad66a32886 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp @@ -0,0 +1,520 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv_kernel_neon.cpp + +Abstract: + + This module implements the single precision convolution kernels for ARM NEON. 
+ +--*/ + +#include "mlasi.h" + #include "sconv.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +// Common implementation for NCHW and NCHWC convolution kernels +template <bool IsNchwcFormat> +void + MLASCALL + MlasConvFloatKernelNeonImpl( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + (void)InputStride; + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + for (size_t filterSetBlock = 0; filterSetBlock < FilterCount; filterSetBlock++) { + const float* filter = Filter + filterSetBlock * FilterStrideElements; + float* output = Output + filterSetBlock * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (IsNchwcFormat) { + for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) { + const
float* input_element = input_base + filterBlock; + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) { + input_value = *input_element; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) + + kw * (BlockSize * BlockSize) + + filterBlock * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } else { + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) { + input_value = *input_base; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * KernelWidth + kw; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} + +void + MLASCALL + MlasConvNchwFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned 
KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl<false>( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +// +// Implementation of MlasConvNchwcFloatKernelNeon +// + +void + MLASCALL + MlasConvNchwcFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl<true>( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +// +// Helper function to load input vector with bounds checking +// +static inline float32x4_t +LoadInputVectorWithBounds( + const float* input_base, + size_t offset, + bool is_main_region, + const float* InputBase, + size_t kh, + size_t DilatedInputWidthElements, + size_t InputWidthElements +) +{ + if (is_main_region) { + return MlasLoadFloat32x4(input_base + offset); + } else { + float input_values[4]; + for (size_t i = 0; i < 4; i++) { + const float* input_element = input_base + offset + i; + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + if (input_element >= input_row_start && input_element < input_row_end) { + input_values[i] = *input_element; + } else { + input_values[i] = 0.0f; + } + } + return MlasLoadFloat32x4(input_values); + } +} + +// +// Implementation of MlasConvDepthwiseFloatKernelNeon +// +// This kernel performs depthwise separable convolution where each input channel +// is convolved with its own filter. This is more efficient than standard convolution +// for certain network architectures like MobileNets.
+// + +void + MLASCALL + MlasConvDepthwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + (void)InputStrideElements; + + const size_t InputWidthElements = InputWidth / sizeof(float); + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&Output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(Bias); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(Bias + 4); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(Bias + 8); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(Bias + 12); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + size_t kernel_pos = kh * KernelWidth + kw; + + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + float32x4_t InputVector0 = LoadInputVectorWithBounds(input_base, 0, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector1 = LoadInputVectorWithBounds(input_base, 4, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector2 = LoadInputVectorWithBounds(input_base, 8, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector3 = LoadInputVectorWithBounds(input_base, 12, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize]); + const float32x4_t 
FilterVector1 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector0, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector1, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector2, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector3, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], Accumulator3); + } +} + +// +// Implementation of MlasConvPointwiseFloatKernelNeon +// +// This kernel performs pointwise (1x1) convolution which is essentially +// a matrix multiplication across the channel dimension. It's optimized +// for cases where the kernel size is 1x1. +// + +void + MLASCALL + MlasConvPointwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) { + for (size_t f = 0; f < FilterCount; f++) { + const float* filter = Filter + f * FilterStrideElements; + float* output = Output + f * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[f * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]); + + 
Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t c = 0; c < InputChannels; c++) { + const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements; + + for (size_t input_b = 0; input_b < BlockSize; input_b++) { + const float input_value = input_ptr[input_b]; + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + const float* filter_ptr = filter + (c * BlockSize + input_b) * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(filter_ptr); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(filter_ptr + 4); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(filter_ptr + 8); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(filter_ptr + 12); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index f9cf1605787aa..2fc27d6d4ad7f 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) || 
defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp b/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp new file mode 100644 index 0000000000000..8cca036d54c3a --- /dev/null +++ b/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp @@ -0,0 +1,289 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + spool_kernel_neon.cpp + +Abstract: + + This module implements the single precision pooling kernels for ARM NEON. 
+ +--*/ + +#include "mlasi.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +void + MLASCALL + MlasPoolMaximumFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(ActualKernelSize); + MLAS_UNREFERENCED_PARAMETER(InputStride); + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + const float MaxPaddingValue = std::numeric_limits<float>::lowest(); + + const MLAS_FLOAT32X4 MaxPaddingVector = MlasBroadcastFloat32x4(MaxPaddingValue); + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + MLAS_FLOAT32X4 MaxVector0 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector1 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector2 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector3 = MaxPaddingVector; + + for (size_t kh = 0; kh < KernelHeight; kh++) { + const float* row_start = InputBase + kh * DilatedInputWidthElements; + const float* row_end = row_start + InputWidthElements; + + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_ptr = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) { + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + } else { + values[i] = MaxPaddingValue; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } + } + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], MaxVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], MaxVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], MaxVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], MaxVector3); + } +} + +static void +MlasPoolAverageFloatKernelNeonImpl( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t
ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + bool ExcludePad +) +{ + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + const MLAS_FLOAT32X4 ZeroVector = MlasZeroFloat32x4(); + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + MLAS_FLOAT32X4 SumVector0 = ZeroVector; + MLAS_FLOAT32X4 SumVector1 = ZeroVector; + MLAS_FLOAT32X4 SumVector2 = ZeroVector; + MLAS_FLOAT32X4 SumVector3 = ZeroVector; + + std::vector<size_t> valid_count; + if (ExcludePad) { + valid_count.resize(BlockSize, 0); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + const float* row_start = InputBase + kh * DilatedInputWidthElements; + const float* row_end = row_start + InputWidthElements; + + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_ptr = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) { + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0); + SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1); + SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2); + SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3); + + if (ExcludePad) { + for (size_t i = 0; i < BlockSize; i++) { + valid_count[i]++; + } + } + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + if (ExcludePad) { + valid_count[i]++; + } + } else { + values[i] = 0.0f; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0); + SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1); + SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2); + SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3); + } + } + } + + if (ExcludePad) { + float results[BlockSize]; + + MlasStoreFloat32x4(&results[0], SumVector0); + MlasStoreFloat32x4(&results[4], SumVector1); + MlasStoreFloat32x4(&results[8], SumVector2); + MlasStoreFloat32x4(&results[12], SumVector3); + + for (size_t i = 0; i < BlockSize; i++) { + results[i] = results[i] / static_cast<float>(valid_count[i]); + } + + MLAS_FLOAT32X4 ResultVector0 = MlasLoadFloat32x4(&results[0]); + MLAS_FLOAT32X4 ResultVector1 = MlasLoadFloat32x4(&results[4]); + MLAS_FLOAT32X4 ResultVector2 = MlasLoadFloat32x4(&results[8]); + MLAS_FLOAT32X4 ResultVector3 = MlasLoadFloat32x4(&results[12]); + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0); + MlasStoreFloat32x4(&Output[output_idx
* BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } else { + const float KernelSize = static_cast<float>(ActualKernelSize); + const MLAS_FLOAT32X4 KernelSizeVector = MlasBroadcastFloat32x4(KernelSize); + + MLAS_FLOAT32X4 ResultVector0 = MlasDivideFloat32x4(SumVector0, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector1 = MlasDivideFloat32x4(SumVector1, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector2 = MlasDivideFloat32x4(SumVector2, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector3 = MlasDivideFloat32x4(SumVector3, KernelSizeVector); + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } + } +} + +void + MLASCALL + MlasPoolAverageExcludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + true // ExcludePad = true + ); +} + +void + MLASCALL + MlasPoolAverageIncludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + false // ExcludePad = false + ); +}
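Reviewer note (illustration only, not part of the patch): the conv and pool kernels above all follow the same register pattern, treating one 16-wide NCHWc block (MLAS_NEON_NCHWC_BLOCK_SIZE) as four float32x4 accumulators and gating the epilogue on the MLAS_CONV_KERNEL_FLAG_* bits from sconv.h. A minimal standalone sketch of that pattern, written against raw arm_neon.h intrinsics rather than the Mlas* wrappers and using a hypothetical process_block() helper with stand-in flag constants, would look roughly like this:

// Illustrative sketch only: one 16-element NCHWc output block handled as four
// float32x4_t lanes, with bias and ReLU controlled by flag bits, mirroring the
// accumulate/bias/ReLU structure of the kernels in this patch.
#include <arm_neon.h>
#include <stddef.h>

#define SKETCH_FLAG_BIAS_ADDITION   0x2u  // analogous to MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION
#define SKETCH_FLAG_RELU_ACTIVATION 0x4u  // analogous to MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION

static void process_block(const float* input, const float* filter, const float* bias,
                          float* output, size_t kernel_size, unsigned flags)
{
    float32x4_t acc[4];
    for (int lane = 0; lane < 4; lane++) {
        // Start each lane at zero, then optionally fold in the per-channel bias.
        acc[lane] = vdupq_n_f32(0.0f);
        if (flags & SKETCH_FLAG_BIAS_ADDITION) {
            acc[lane] = vaddq_f32(acc[lane], vld1q_f32(bias + 4 * lane));
        }
    }
    for (size_t k = 0; k < kernel_size; k++) {
        // Broadcast one input element and multiply-accumulate it against the
        // 16 filter values stored contiguously for this kernel position.
        float32x4_t in = vdupq_n_f32(input[k]);
        for (int lane = 0; lane < 4; lane++) {
            acc[lane] = vfmaq_f32(acc[lane], in, vld1q_f32(filter + 16 * k + 4 * lane));
        }
    }
    for (int lane = 0; lane < 4; lane++) {
        // Optional ReLU epilogue, then store the 16 outputs for this block.
        if (flags & SKETCH_FLAG_RELU_ACTIVATION) {
            acc[lane] = vmaxq_f32(acc[lane], vdupq_n_f32(0.0f));
        }
        vst1q_f32(output + 4 * lane, acc[lane]);
    }
}

The real kernels additionally handle output accumulation, padding bounds checks against InputBase/InputWidth, stride and dilation addressing, and the NCHW versus NCHWc filter layouts; this sketch only captures the per-block vector layout implied by the 16-element block size.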