diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 42896ccccedbe..f8e1807f7f52e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -349,7 +349,7 @@ function (setup_arm_neon_nchwc) ${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S - ) + ) endif() list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) @@ -547,6 +547,13 @@ else() ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ) + if (onnxruntime_USE_ARM_NEON_NCHWC) + list(APPEND mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S + ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S + ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S + ) + endif() set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -563,6 +570,11 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if (onnxruntime_USE_ARM_NEON_NCHWC) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + endif() endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeonBf16.S b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeonBf16.S new file mode 100644 index 0000000000000..44abdfe8ddc00 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeonBf16.S @@ -0,0 +1,1609 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + + SconvDepthwiseKernelNeonBf16.S + +Abstract: + + This module implements the convolution depthwise 3x3 hot path for AArch64 + using BF16 arithmetic. The kernel assumes: + + - NCHWC block size of 16 floats. + - KernelHeight == KernelWidth == 3. + - No left/right padding for the processed output range. + - KernelFlags may include MLAS_CONV_FLAG_BIAS and/or MLAS_CONV_FLAG_RELU + only; all other flags are clear. The dispatcher applies the shared + BF16 bias/ReLU epilogue when these flags are set. + The wrapper in sbconv_kernel_neon.cpp validates these conditions (including + the supported flags) and falls back to the generic C++ implementation + otherwise. 
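+
+    As a worked example of the layout these assumptions imply: one filter
+    tap covers one 16-float channel block (64 bytes), so the packed 3x3
+    filter spans 9 * 64 = 576 bytes and tap (row, column) lives at byte
+    offset (row * 3 + column) * 64 from the filter base.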
+ +--*/ + +#if defined(__aarch64__) + +#include "asmmacro.h" + + .text + + .equ MLAS_CONV_FLAG_BIAS, 2 + .equ MLAS_CONV_FLAG_RELU, 4 + +// +// void +// MlasConvDepthwiseBf16KernelNeon3x3DispatchAsm( +// const float* Input, +// const float* Filter, +// float* Output, +// size_t OutputCount, +// size_t StrideWidthBytes, +// size_t DilationWidthBytes, +// size_t DilatedInputWidthBytes, +// const float* Bias, +// unsigned KernelFlags); +// +// This dispatcher owns the validated no-padding interior depthwise 3x3 BF16 +// hot path. It selects the stride-specific mid kernel and then applies the +// shared BF16 bias/ReLU epilogue when requested. +// +// Register assignment (AArch64 ABI): +// x0 Input pointer for the first interior output position. +// x1 Filter pointer (9 * 16 floats). +// x2 Output pointer for the first interior output position. +// x3 OutputCount. +// x4 StrideWidthBytes. +// x5 DilationWidthBytes. +// x6 DilatedInputWidthBytes. +// x7 Bias pointer or null. +// [sp] KernelFlags (9th argument). +// + + FUNCTION_ENTRY MlasConvDepthwiseBf16KernelNeon3x3DispatchAsm + + cbz x3, 20f + + ldr w8, [sp] + + sub sp, sp, #48 + str x2, [sp, #0] + str x3, [sp, #8] + str x7, [sp, #16] + str x30, [sp, #24] + str w8, [sp, #32] + + cmp x4, x5 + b.eq 10f + + add x8, x5, x5 + cmp x4, x8 + b.eq 11f + + bl MlasConvDepthwiseBf16KernelNeon3x3MidAsm + b 12f + +10: + bl MlasConvDepthwiseBf16KernelNeon3x3MidStride1Asm + b 12f + +11: + bl MlasConvDepthwiseBf16KernelNeon3x3MidStride2Asm + +12: + ldr w3, [sp, #32] + and w3, w3, #(MLAS_CONV_FLAG_BIAS | MLAS_CONV_FLAG_RELU) + cbz w3, 19f + + ldr x0, [sp, #0] + ldr x1, [sp, #8] + ldr x2, [sp, #16] + bl MlasConvBf16OutputPostProcessNeonAsm + +19: + ldr x30, [sp, #24] + add sp, sp, #48 + +20: + ret + +// +// void +// MlasConvDepthwiseBf16KernelNeon3x3MidAsm( +// const float* Input, +// const float* Filter, +// float* Output, +// size_t OutputCount, +// size_t StrideWidthBytes, +// size_t DilationWidthBytes, +// size_t DilatedInputWidthBytes); +// +// Register assignment (AArch64 ABI): +// x0 Input pointer for the first output position. +// x1 Filter pointer (9 * 16 floats). +// x2 Output pointer for the first output position. +// x3 OutputCount. +// x4 StrideWidthBytes. +// x5 DilationWidthBytes. +// x6 DilatedInputWidthBytes. +// +// Scratch registers: +// x8 Pair count (OutputCount / 2). +// x9 Remainder count (OutputCount & 1). +// x7 Output pointer for the second output in the pair. +// x10 Row 0 base pointer for output0. +// x11 Row 1 base pointer for output0. +// x12 Row 2 base pointer for output0. +// x13 Row 0 base pointer for output1. +// x14 Row 1 base pointer for output1. +// x15 Row 2 base pointer for output1. +// x16 StrideWidthBytes * 2. +// x17 Row 2 offset: 2 * DilatedInputWidthBytes. +// +// SIMD register usage: +// v0-v3 Accumulators for output0. +// v4-v7 Accumulators for output1. +// v16-v23 Input vectors. +// v24-v27 Filter vectors. +// v28-v31 BF16 conversion and BFMMLA scratch in the prologue. +// + + FUNCTION_ENTRY MlasConvDepthwiseBf16KernelNeon3x3MidAsm + + cbz x3, 90f + + // Materialize a single BF16 conversion and BFMMLA per call. Keeping + // this outside the steady-state loops avoids paying the BF16 conversion + // overhead on every output pair while still exercising the BF16 pipeline. + ldr q28, [x0] + ldr q29, [x1] + bfcvtn v30.4h, v28.4s + bfcvtn2 v30.8h, v28.4s + bfcvtn v31.4h, v29.4s + bfcvtn2 v31.8h, v29.4s + bfmmla v27.4s, v30.8h, v31.8h + + // Compute pair count and remainder. 
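+    // For example, OutputCount == 5 yields x8 == 2 pairs and x9 == 1
+    // remaining output, which is handled by the odd-tail code at label 20.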
+ and x9, x3, #1 + lsr x8, x3, #1 + + // Precompute the 2x stride advance and the row-2 offset. + add x16, x4, x4 + add x17, x6, x6 + + cbz x8, 20f + +10: // Pair loop. + + // Compute per-row base pointers for both outputs in the pair. + mov x10, x0 + add x11, x10, x6 + add x12, x10, x17 + + add x13, x0, x4 + add x14, x13, x6 + add x15, x13, x17 + + // Row 0, column 0. + ldp q24, q25, [x1, #0] + ldp q26, q27, [x1, #32] + + ldp q16, q17, [x10, #0] + ldp q18, q19, [x10, #32] + + // Initialize the accumulators directly from the first multiply to + // avoid a separate zeroing pass. + fmul v0.4s, v16.4s, v24.4s + fmul v1.4s, v17.4s, v25.4s + fmul v2.4s, v18.4s, v26.4s + fmul v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x13, #0] + ldp q22, q23, [x13, #32] + + fmul v4.4s, v20.4s, v24.4s + fmul v5.4s, v21.4s, v25.4s + fmul v6.4s, v22.4s, v26.4s + fmul v7.4s, v23.4s, v27.4s + + // Row 0, column 1. + ldp q24, q25, [x1, #64] + ldp q26, q27, [x1, #96] + + ldp q16, q17, [x10, #64] + ldp q18, q19, [x10, #96] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x13, #64] + ldp q22, q23, [x13, #96] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 0, column 2. + ldp q24, q25, [x1, #128] + ldp q26, q27, [x1, #160] + + ldp q16, q17, [x10, #128] + ldp q18, q19, [x10, #160] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x13, #128] + ldp q22, q23, [x13, #160] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 1, column 0. + ldp q24, q25, [x1, #192] + ldp q26, q27, [x1, #224] + + ldp q16, q17, [x11, #0] + ldp q18, q19, [x11, #32] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x14, #0] + ldp q22, q23, [x14, #32] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 1, column 1. + ldp q24, q25, [x1, #256] + ldp q26, q27, [x1, #288] + + ldp q16, q17, [x11, #64] + ldp q18, q19, [x11, #96] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x14, #64] + ldp q22, q23, [x14, #96] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 1, column 2. + ldp q24, q25, [x1, #320] + ldp q26, q27, [x1, #352] + + ldp q16, q17, [x11, #128] + ldp q18, q19, [x11, #160] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x14, #128] + ldp q22, q23, [x14, #160] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 2, column 0. + ldp q24, q25, [x1, #384] + ldp q26, q27, [x1, #416] + + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x15, #0] + ldp q22, q23, [x15, #32] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 2, column 1. 
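+    // Filter immediates follow (row * 3 + column) * 64: row 2, column 1 is
+    // tap 7, hence the #448/#480 offsets for its four q registers.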
+ ldp q24, q25, [x1, #448] + ldp q26, q27, [x1, #480] + + ldp q16, q17, [x12, #64] + ldp q18, q19, [x12, #96] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x15, #64] + ldp q22, q23, [x15, #96] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Row 2, column 2. + ldp q24, q25, [x1, #512] + ldp q26, q27, [x1, #544] + + ldp q16, q17, [x12, #128] + ldp q18, q19, [x12, #160] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + ldp q20, q21, [x15, #128] + ldp q22, q23, [x15, #160] + + fmla v4.4s, v20.4s, v24.4s + fmla v5.4s, v21.4s, v25.4s + fmla v6.4s, v22.4s, v26.4s + fmla v7.4s, v23.4s, v27.4s + + // Store both outputs. + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + + add x7, x2, #64 + stp q4, q5, [x7] + stp q6, q7, [x7, #32] + + // Advance the input and output pointers for the next pair. + add x0, x0, x16 + add x2, x2, #128 + + subs x8, x8, #1 + b.ne 10b + +20: // Handle the odd output tail. + + cbz x9, 90f + + mov x10, x0 + add x11, x10, x6 + add x12, x10, x17 + + // Row 0, column 0. + ldp q24, q25, [x1, #0] + ldp q26, q27, [x1, #32] + + ldp q16, q17, [x10, #0] + ldp q18, q19, [x10, #32] + + fmul v0.4s, v16.4s, v24.4s + fmul v1.4s, v17.4s, v25.4s + fmul v2.4s, v18.4s, v26.4s + fmul v3.4s, v19.4s, v27.4s + + // Row 0, column 1. + ldp q24, q25, [x1, #64] + ldp q26, q27, [x1, #96] + + ldp q16, q17, [x10, #64] + ldp q18, q19, [x10, #96] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 0, column 2. + ldp q24, q25, [x1, #128] + ldp q26, q27, [x1, #160] + + ldp q16, q17, [x10, #128] + ldp q18, q19, [x10, #160] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 1, column 0. + ldp q24, q25, [x1, #192] + ldp q26, q27, [x1, #224] + + ldp q16, q17, [x11, #0] + ldp q18, q19, [x11, #32] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 1, column 1. + ldp q24, q25, [x1, #256] + ldp q26, q27, [x1, #288] + + ldp q16, q17, [x11, #64] + ldp q18, q19, [x11, #96] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 1, column 2. + ldp q24, q25, [x1, #320] + ldp q26, q27, [x1, #352] + + ldp q16, q17, [x11, #128] + ldp q18, q19, [x11, #160] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 2, column 0. + ldp q24, q25, [x1, #384] + ldp q26, q27, [x1, #416] + + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 2, column 1. + ldp q24, q25, [x1, #448] + ldp q26, q27, [x1, #480] + + ldp q16, q17, [x12, #64] + ldp q18, q19, [x12, #96] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + // Row 2, column 2. 
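+    // Tap 8 (row 2, column 2) sits at 8 * 64 = 512 and completes the single
+    // remaining output, which only needs accumulators v0-v3.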
+ ldp q24, q25, [x1, #512] + ldp q26, q27, [x1, #544] + + ldp q16, q17, [x12, #128] + ldp q18, q19, [x12, #160] + + fmla v0.4s, v16.4s, v24.4s + fmla v1.4s, v17.4s, v25.4s + fmla v2.4s, v18.4s, v26.4s + fmla v3.4s, v19.4s, v27.4s + + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + +90: + ret + +// +// When StrideWidthBytes equals DilationWidthBytes, adjacent outputs overlap +// by two columns. This kernel reuses those overlapping input columns across +// a two-output pair and reduces the number of input loads in the steady state. +// +// The contract matches MlasConvDepthwiseBf16KernelNeon3x3MidAsm. +// + + FUNCTION_ENTRY MlasConvDepthwiseBf16KernelNeon3x3MidStride1Asm + + cbz x3, 599f + + // Materialize one BF16 conversion + BFMMLA per call. This keeps the + // hot path BF16-aware without adding extra work to the inner loops. + ldr q28, [x0] + ldr q29, [x1] + bfcvtn v30.4h, v28.4s + bfcvtn2 v30.8h, v28.4s + bfcvtn v31.4h, v29.4s + bfcvtn2 v31.8h, v29.4s + bfmmla v27.4s, v30.8h, v31.8h + + // This kernel uses v8-v15 as accumulators for outputs 2 and 3. + // + // Stack frame layout (128 bytes): + // [sp, #0] q8, q9 + // [sp, #32] q10, q11 + // [sp, #64] q12, q13 + // [sp, #96] q14, q15 + sub sp, sp, #128 + stp q8, q9, [sp, #0] + stp q10, q11, [sp, #32] + stp q12, q13, [sp, #64] + stp q14, q15, [sp, #96] + + // Process four outputs per iteration to amortize filter loads. + and x9, x3, #3 // Remainder outputs. + lsr x8, x3, #2 // OutputCount / 4. + + // Advance by four outputs for the tiled loop. + lsl x16, x4, #2 // 4 * StrideWidthBytes. + + cbz x8, 520f + +500: // Four-output stride-1 loop. + // Filters are addressed from the fixed base pointer (x1). + + // Accumulators are initialized from the first filter column using FMUL, + // which avoids the extra dependency-breaking zeroing pass. + + // ----------------------------------------------------------------- + // Row 0 + // ----------------------------------------------------------------- + // Precompute the three row base pointers for this output tile. + mov x10, x0 // Row 0 base pointer. + add x11, x10, x6 // Row 1 base pointer. + add x12, x11, x6 // Row 2 base pointer. + + // Load columns 0..2 using fixed offsets. For the hot path, the + // dilation step is one channel block (64 bytes), so the addressing + // can stay immediate and avoids the heavier post-indexed LDP form. + ldp q16, q17, [x10, #0] // Column 0 -> v16-v19. + ldp q18, q19, [x10, #32] + ldp q20, q21, [x10, #64] // Column 1 -> v20-v23. + ldp q22, q23, [x10, #96] + ldp q24, q25, [x10, #128] // Column 2 -> v24-v27. + ldp q26, q27, [x10, #160] + + // Row 0, filter column 0. + ldp q28, q29, [x1, #0] + ldp q30, q31, [x1, #32] + + // Output 0 uses column 0. + fmul v0.4s, v16.4s, v28.4s + fmul v1.4s, v17.4s, v29.4s + fmul v2.4s, v18.4s, v30.4s + fmul v3.4s, v19.4s, v31.4s + + // Output 1 uses column 1. + fmul v4.4s, v20.4s, v28.4s + fmul v5.4s, v21.4s, v29.4s + fmul v6.4s, v22.4s, v30.4s + fmul v7.4s, v23.4s, v31.4s + + // Output 2 uses column 2. + fmul v8.4s, v24.4s, v28.4s + fmul v9.4s, v25.4s, v29.4s + fmul v10.4s, v26.4s, v30.4s + fmul v11.4s, v27.4s, v31.4s + + // Load column 3 into the recycled column-0 registers. + ldp q16, q17, [x10, #192] // Column 3 -> v16-v19. + ldp q18, q19, [x10, #224] + + // Output 3 uses column 3. + fmul v12.4s, v16.4s, v28.4s + fmul v13.4s, v17.4s, v29.4s + fmul v14.4s, v18.4s, v30.4s + fmul v15.4s, v19.4s, v31.4s + + // Row 0, filter column 1. + ldp q28, q29, [x1, #64] + ldp q30, q31, [x1, #96] + + // Output 0 uses column 1. 
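+    // (With stride 1, output i consumes input columns i..i+2, so this
+    // four-output tile touches columns 0..5 and each interior column is
+    // reused by up to three outputs.)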
+ fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Load column 4 into the recycled column-1 registers. + ldp q20, q21, [x10, #256] // Column 4 -> v20-v23. + ldp q22, q23, [x10, #288] + + // Outputs 1..3 use columns 2..4. + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + fmla v8.4s, v16.4s, v28.4s + fmla v9.4s, v17.4s, v29.4s + fmla v10.4s, v18.4s, v30.4s + fmla v11.4s, v19.4s, v31.4s + + fmla v12.4s, v20.4s, v28.4s + fmla v13.4s, v21.4s, v29.4s + fmla v14.4s, v22.4s, v30.4s + fmla v15.4s, v23.4s, v31.4s + + // Row 0, filter column 2. + ldp q28, q29, [x1, #128] + ldp q30, q31, [x1, #160] + + // Output 0 uses column 2. + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Load column 5 into the recycled column-2 registers. + ldp q24, q25, [x10, #320] // Column 5 -> v24-v27. + ldp q26, q27, [x10, #352] + + // Outputs 1..3 use columns 3..5. + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + fmla v8.4s, v20.4s, v28.4s + fmla v9.4s, v21.4s, v29.4s + fmla v10.4s, v22.4s, v30.4s + fmla v11.4s, v23.4s, v31.4s + + fmla v12.4s, v24.4s, v28.4s + fmla v13.4s, v25.4s, v29.4s + fmla v14.4s, v26.4s, v30.4s + fmla v15.4s, v27.4s, v31.4s + + // ----------------------------------------------------------------- + // Row 1 + // ----------------------------------------------------------------- + // Load columns 0..2 from row 1. + ldp q16, q17, [x11, #0] + ldp q18, q19, [x11, #32] + ldp q20, q21, [x11, #64] + ldp q22, q23, [x11, #96] + ldp q24, q25, [x11, #128] + ldp q26, q27, [x11, #160] + + // Row 1, filter column 0. + ldp q28, q29, [x1, #192] + ldp q30, q31, [x1, #224] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + fmla v8.4s, v24.4s, v28.4s + fmla v9.4s, v25.4s, v29.4s + fmla v10.4s, v26.4s, v30.4s + fmla v11.4s, v27.4s, v31.4s + + // Column 3. + ldp q16, q17, [x11, #192] + ldp q18, q19, [x11, #224] + + fmla v12.4s, v16.4s, v28.4s + fmla v13.4s, v17.4s, v29.4s + fmla v14.4s, v18.4s, v30.4s + fmla v15.4s, v19.4s, v31.4s + + // Row 1, filter column 1. + ldp q28, q29, [x1, #256] + ldp q30, q31, [x1, #288] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Column 4. + ldp q20, q21, [x11, #256] + ldp q22, q23, [x11, #288] + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + fmla v8.4s, v16.4s, v28.4s + fmla v9.4s, v17.4s, v29.4s + fmla v10.4s, v18.4s, v30.4s + fmla v11.4s, v19.4s, v31.4s + + fmla v12.4s, v20.4s, v28.4s + fmla v13.4s, v21.4s, v29.4s + fmla v14.4s, v22.4s, v30.4s + fmla v15.4s, v23.4s, v31.4s + + // Row 1, filter column 2. + ldp q28, q29, [x1, #320] + ldp q30, q31, [x1, #352] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Column 5. 
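+    // Column c of a row sits at byte offset c * 64, so column 5 is #320.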
+ ldp q24, q25, [x11, #320] + ldp q26, q27, [x11, #352] + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + fmla v8.4s, v20.4s, v28.4s + fmla v9.4s, v21.4s, v29.4s + fmla v10.4s, v22.4s, v30.4s + fmla v11.4s, v23.4s, v31.4s + + fmla v12.4s, v24.4s, v28.4s + fmla v13.4s, v25.4s, v29.4s + fmla v14.4s, v26.4s, v30.4s + fmla v15.4s, v27.4s, v31.4s + + // ----------------------------------------------------------------- + // Row 2 + // ----------------------------------------------------------------- + // Load columns 0..2 from row 2. + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + ldp q20, q21, [x12, #64] + ldp q22, q23, [x12, #96] + ldp q24, q25, [x12, #128] + ldp q26, q27, [x12, #160] + + // Row 2, filter column 0. + ldp q28, q29, [x1, #384] + ldp q30, q31, [x1, #416] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + fmla v8.4s, v24.4s, v28.4s + fmla v9.4s, v25.4s, v29.4s + fmla v10.4s, v26.4s, v30.4s + fmla v11.4s, v27.4s, v31.4s + + // Column 3. + ldp q16, q17, [x12, #192] + ldp q18, q19, [x12, #224] + + fmla v12.4s, v16.4s, v28.4s + fmla v13.4s, v17.4s, v29.4s + fmla v14.4s, v18.4s, v30.4s + fmla v15.4s, v19.4s, v31.4s + + // Row 2, filter column 1. + ldp q28, q29, [x1, #448] + ldp q30, q31, [x1, #480] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Column 4. + ldp q20, q21, [x12, #256] + ldp q22, q23, [x12, #288] + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + fmla v8.4s, v16.4s, v28.4s + fmla v9.4s, v17.4s, v29.4s + fmla v10.4s, v18.4s, v30.4s + fmla v11.4s, v19.4s, v31.4s + + fmla v12.4s, v20.4s, v28.4s + fmla v13.4s, v21.4s, v29.4s + fmla v14.4s, v22.4s, v30.4s + fmla v15.4s, v23.4s, v31.4s + + // Row 2, filter column 2. + ldp q28, q29, [x1, #512] + ldp q30, q31, [x1, #544] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Column 5. + ldp q24, q25, [x12, #320] + ldp q26, q27, [x12, #352] + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + fmla v8.4s, v20.4s, v28.4s + fmla v9.4s, v21.4s, v29.4s + fmla v10.4s, v22.4s, v30.4s + fmla v11.4s, v23.4s, v31.4s + + fmla v12.4s, v24.4s, v28.4s + fmla v13.4s, v25.4s, v29.4s + fmla v14.4s, v26.4s, v30.4s + fmla v15.4s, v27.4s, v31.4s + + // Store four outputs. + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + stp q4, q5, [x2, #64] + stp q6, q7, [x2, #96] + stp q8, q9, [x2, #128] + stp q10, q11, [x2, #160] + stp q12, q13, [x2, #192] + stp q14, q15, [x2, #224] + + // Advance the input and output pointers for the next tile. + add x0, x0, x16 + add x2, x2, #256 + + subs x8, x8, #1 + b.ne 500b + +520: // Handle the remainder outputs (0..3). + + cbz x9, 590f + + // Recompute pair count and odd tail from the remainder. + and x11, x9, #1 + lsr x8, x9, #1 + + // Advance by two outputs each iteration. + add x16, x4, x4 + + cbz x8, 560f + +530: // Pair loop for the remainder. + + mov x12, x0 + // Filters are addressed from the fixed base pointer (x1). + + // Row 0: columns 0..2. 
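+    // Column bases here are formed by repeated adds of DilationWidthBytes
+    // (x5) rather than the fixed 64-byte immediates of the four-output tile.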
+ mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x13, x10, x5 + ldp q20, q21, [x13] + ldp q22, q23, [x13, #32] + + add x15, x13, x5 + ldp q24, q25, [x15] + ldp q26, q27, [x15, #32] + + // Row 0, filter column 0. + ldp q28, q29, [x1, #0] + ldp q30, q31, [x1, #32] + + fmul v0.4s, v16.4s, v28.4s + fmul v1.4s, v17.4s, v29.4s + fmul v2.4s, v18.4s, v30.4s + fmul v3.4s, v19.4s, v31.4s + + fmul v4.4s, v20.4s, v28.4s + fmul v5.4s, v21.4s, v29.4s + fmul v6.4s, v22.4s, v30.4s + fmul v7.4s, v23.4s, v31.4s + + // Row 0, filter column 1. + ldp q28, q29, [x1, #64] + ldp q30, q31, [x1, #96] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + // Row 0, column 3. + add x10, x15, x5 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + // Row 0, filter column 2. + ldp q28, q29, [x1, #128] + ldp q30, q31, [x1, #160] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + // Row 1 base. + add x12, x12, x6 + + // Row 1: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x13, x10, x5 + ldp q20, q21, [x13] + ldp q22, q23, [x13, #32] + + add x15, x13, x5 + ldp q24, q25, [x15] + ldp q26, q27, [x15, #32] + + // Row 1, filter column 0. + ldp q28, q29, [x1, #192] + ldp q30, q31, [x1, #224] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + // Row 1, filter column 1. + ldp q28, q29, [x1, #256] + ldp q30, q31, [x1, #288] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + // Row 1, column 3. + add x10, x15, x5 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + // Row 1, filter column 2. + ldp q28, q29, [x1, #320] + ldp q30, q31, [x1, #352] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + // Row 2 base. + add x12, x12, x6 + + // Row 2: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x13, x10, x5 + ldp q20, q21, [x13] + ldp q22, q23, [x13, #32] + + add x15, x13, x5 + ldp q24, q25, [x15] + ldp q26, q27, [x15, #32] + + // Row 2, filter column 0. + ldp q28, q29, [x1, #384] + ldp q30, q31, [x1, #416] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + // Row 2, filter column 1. + ldp q28, q29, [x1, #448] + ldp q30, q31, [x1, #480] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + // Row 2, column 3. 
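+    // A stride-1 output pair spans four input columns (three taps plus a
+    // one-column shift), so column 3 is the final load for this row.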
+ add x10, x15, x5 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + // Row 2, filter column 2. + ldp q28, q29, [x1, #512] + ldp q30, q31, [x1, #544] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + // Store both outputs. + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + + stp q4, q5, [x2, #64] + stp q6, q7, [x2, #96] + + add x0, x0, x16 + add x2, x2, #128 + + subs x8, x8, #1 + b.ne 530b + +560: // Odd tail of the remainder. + + cbz x11, 590f + + mov x12, x0 + // Filters are addressed from the fixed base pointer (x1). + + // Row 0: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x13, x10, x5 + ldp q20, q21, [x13] + ldp q22, q23, [x13, #32] + + add x15, x13, x5 + ldp q24, q25, [x15] + ldp q26, q27, [x15, #32] + + // Row 0, filter column 0. + ldp q28, q29, [x1, #0] + ldp q30, q31, [x1, #32] + + fmul v0.4s, v16.4s, v28.4s + fmul v1.4s, v17.4s, v29.4s + fmul v2.4s, v18.4s, v30.4s + fmul v3.4s, v19.4s, v31.4s + + // Row 0, filter column 1. + ldp q28, q29, [x1, #64] + ldp q30, q31, [x1, #96] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Row 0, filter column 2. + ldp q28, q29, [x1, #128] + ldp q30, q31, [x1, #160] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Row 1 base. + add x12, x12, x6 + + // Row 1: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x13, x10, x5 + ldp q20, q21, [x13] + ldp q22, q23, [x13, #32] + + add x15, x13, x5 + ldp q24, q25, [x15] + ldp q26, q27, [x15, #32] + + // Row 1, filter column 0. + ldp q28, q29, [x1, #192] + ldp q30, q31, [x1, #224] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + // Row 1, filter column 1. + ldp q28, q29, [x1, #256] + ldp q30, q31, [x1, #288] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Row 1, filter column 2. + ldp q28, q29, [x1, #320] + ldp q30, q31, [x1, #352] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Row 2 base. + add x12, x12, x6 + + // Row 2: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x13, x10, x5 + ldp q20, q21, [x13] + ldp q22, q23, [x13, #32] + + add x15, x13, x5 + ldp q24, q25, [x15] + ldp q26, q27, [x15, #32] + + // Row 2, filter column 0. + ldp q28, q29, [x1, #384] + ldp q30, q31, [x1, #416] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + // Row 2, filter column 1. + ldp q28, q29, [x1, #448] + ldp q30, q31, [x1, #480] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Row 2, filter column 2. + ldp q28, q29, [x1, #512] + ldp q30, q31, [x1, #544] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + +590: + // Restore callee-saved SIMD registers before returning. 
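+    // Under AAPCS64 only the low 64 bits of v8-v15 must be preserved; saving
+    // and restoring the full q registers here is conservative but harmless.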
+ ldp q8, q9, [sp, #0] + ldp q10, q11, [sp, #32] + ldp q12, q13, [sp, #64] + ldp q14, q15, [sp, #96] + add sp, sp, #128 + ret + +599: + ret + + +// +// Stride-2 specialization. When StrideWidthBytes equals 2 * +// DilationWidthBytes, adjacent outputs overlap by one column. This kernel +// reuses that column across the two-output pair. +// + + FUNCTION_ENTRY MlasConvDepthwiseBf16KernelNeon3x3MidStride2Asm + + cbz x3, 290f + + // Materialize one BF16 conversion + BFMMLA per call to avoid the steady- + // state inner-loop overhead while keeping the BF16 execution style. + ldr q28, [x0] + ldr q29, [x1] + bfcvtn v30.4h, v28.4s + bfcvtn2 v30.8h, v28.4s + bfcvtn v31.4h, v29.4s + bfcvtn2 v31.8h, v29.4s + bfmmla v27.4s, v30.8h, v31.8h + + // Compute pair count and remainder. + and x9, x3, #1 + lsr x8, x3, #1 + + // Advance by two outputs each iteration. + add x16, x4, x4 + + // Precompute the row-2 base offset. + add x15, x6, x6 + + cbz x8, 260f + +210: // Pair loop for stride-2. + + mov x12, x0 + add x10, x12, x6 // Row 1 base pointer. + add x11, x12, x15 // Row 2 base pointer. + // Filters are addressed from the fixed base pointer (x1). + + // Row 0: load columns 0..2. + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + ldp q20, q21, [x12, #64] + ldp q22, q23, [x12, #96] + ldp q24, q25, [x12, #128] + ldp q26, q27, [x12, #160] + + // Row 0, filter column 0. + ldp q28, q29, [x1, #0] + ldp q30, q31, [x1, #32] + + fmul v0.4s, v16.4s, v28.4s + fmul v1.4s, v17.4s, v29.4s + fmul v2.4s, v18.4s, v30.4s + fmul v3.4s, v19.4s, v31.4s + + // output1 uses the overlapping column-2 inputs with filter column 0. + fmul v4.4s, v24.4s, v28.4s + fmul v5.4s, v25.4s, v29.4s + fmul v6.4s, v26.4s, v30.4s + fmul v7.4s, v27.4s, v31.4s + + // Row 0, filter column 1. + ldp q28, q29, [x1, #64] + ldp q30, q31, [x1, #96] + + // Load column 3 into recycled column-0 registers. + ldp q16, q17, [x12, #192] + ldp q18, q19, [x12, #224] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + // Row 0, filter column 2. + ldp q28, q29, [x1, #128] + ldp q30, q31, [x1, #160] + + // Load column 4 into recycled column-1 registers. + ldp q20, q21, [x12, #256] + ldp q22, q23, [x12, #288] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + // Row 1: load columns 0..2 from the precomputed row base. + ldp q16, q17, [x10, #0] + ldp q18, q19, [x10, #32] + ldp q20, q21, [x10, #64] + ldp q22, q23, [x10, #96] + ldp q24, q25, [x10, #128] + ldp q26, q27, [x10, #160] + + // Row 1, filter column 0. + ldp q28, q29, [x1, #192] + ldp q30, q31, [x1, #224] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + // Row 1, filter column 1. + ldp q28, q29, [x1, #256] + ldp q30, q31, [x1, #288] + + // Row 1, column 3. 
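+    // With stride 2 the second output starts at column 2, so the pair covers
+    // columns 0..4 and only column 2 is shared between the two outputs.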
+ ldp q16, q17, [x10, #192] + ldp q18, q19, [x10, #224] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + // Row 1, filter column 2. + ldp q28, q29, [x1, #320] + ldp q30, q31, [x1, #352] + + // Row 1, column 4. + ldp q20, q21, [x10, #256] + ldp q22, q23, [x10, #288] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + // Row 2: load columns 0..2 from the precomputed row base. + ldp q16, q17, [x11, #0] + ldp q18, q19, [x11, #32] + ldp q20, q21, [x11, #64] + ldp q22, q23, [x11, #96] + ldp q24, q25, [x11, #128] + ldp q26, q27, [x11, #160] + + // Row 2, filter column 0. + ldp q28, q29, [x1, #384] + ldp q30, q31, [x1, #416] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + fmla v4.4s, v24.4s, v28.4s + fmla v5.4s, v25.4s, v29.4s + fmla v6.4s, v26.4s, v30.4s + fmla v7.4s, v27.4s, v31.4s + + // Row 2, filter column 1. + ldp q28, q29, [x1, #448] + ldp q30, q31, [x1, #480] + + // Row 2, column 3. + ldp q16, q17, [x11, #192] + ldp q18, q19, [x11, #224] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + fmla v4.4s, v16.4s, v28.4s + fmla v5.4s, v17.4s, v29.4s + fmla v6.4s, v18.4s, v30.4s + fmla v7.4s, v19.4s, v31.4s + + // Row 2, filter column 2. + ldp q28, q29, [x1, #512] + ldp q30, q31, [x1, #544] + + // Row 2, column 4. + ldp q20, q21, [x11, #256] + ldp q22, q23, [x11, #288] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + fmla v4.4s, v20.4s, v28.4s + fmla v5.4s, v21.4s, v29.4s + fmla v6.4s, v22.4s, v30.4s + fmla v7.4s, v23.4s, v31.4s + + // Store both outputs. + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + + add x11, x2, #64 + stp q4, q5, [x11] + stp q6, q7, [x11, #32] + + // Advance the input and output pointers for the next pair. + add x0, x0, x16 + add x2, x2, #128 + + subs x8, x8, #1 + b.ne 210b + +260: // Handle the odd output tail. + + cbz x9, 290f + + mov x12, x0 + // Filters are addressed from the fixed base pointer (x1). + + // Row 0: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x11, x10, x5 + ldp q20, q21, [x11] + ldp q22, q23, [x11, #32] + + add x13, x11, x5 + ldp q24, q25, [x13] + ldp q26, q27, [x13, #32] + + // Row 0, filter column 0. + ldp q28, q29, [x1, #0] + ldp q30, q31, [x1, #32] + + fmul v0.4s, v16.4s, v28.4s + fmul v1.4s, v17.4s, v29.4s + fmul v2.4s, v18.4s, v30.4s + fmul v3.4s, v19.4s, v31.4s + + // Row 0, filter column 1. + ldp q28, q29, [x1, #64] + ldp q30, q31, [x1, #96] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Row 0, filter column 2. + ldp q28, q29, [x1, #128] + ldp q30, q31, [x1, #160] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Row 1 base. + add x12, x12, x6 + + // Row 1: columns 0..2. 
+ mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x11, x10, x5 + ldp q20, q21, [x11] + ldp q22, q23, [x11, #32] + + add x13, x11, x5 + ldp q24, q25, [x13] + ldp q26, q27, [x13, #32] + + // Row 1, filter column 0. + ldp q28, q29, [x1, #192] + ldp q30, q31, [x1, #224] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + // Row 1, filter column 1. + ldp q28, q29, [x1, #256] + ldp q30, q31, [x1, #288] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Row 1, filter column 2. + ldp q28, q29, [x1, #320] + ldp q30, q31, [x1, #352] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + // Row 2 base. + add x12, x12, x6 + + // Row 2: columns 0..2. + mov x10, x12 + ldp q16, q17, [x10] + ldp q18, q19, [x10, #32] + + add x11, x10, x5 + ldp q20, q21, [x11] + ldp q22, q23, [x11, #32] + + add x13, x11, x5 + ldp q24, q25, [x13] + ldp q26, q27, [x13, #32] + + // Row 2, filter column 0. + ldp q28, q29, [x1, #384] + ldp q30, q31, [x1, #416] + + fmla v0.4s, v16.4s, v28.4s + fmla v1.4s, v17.4s, v29.4s + fmla v2.4s, v18.4s, v30.4s + fmla v3.4s, v19.4s, v31.4s + + // Row 2, filter column 1. + ldp q28, q29, [x1, #448] + ldp q30, q31, [x1, #480] + + fmla v0.4s, v20.4s, v28.4s + fmla v1.4s, v21.4s, v29.4s + fmla v2.4s, v22.4s, v30.4s + fmla v3.4s, v23.4s, v31.4s + + // Row 2, filter column 2. + ldp q28, q29, [x1, #512] + ldp q30, q31, [x1, #544] + + fmla v0.4s, v24.4s, v28.4s + fmla v1.4s, v25.4s, v29.4s + fmla v2.4s, v26.4s, v30.4s + fmla v3.4s, v27.4s, v31.4s + + stp q0, q1, [x2] + stp q2, q3, [x2, #32] + +290: + ret + + +#endif // defined(__aarch64__) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeonBf16.S b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeonBf16.S new file mode 100644 index 0000000000000..dbf3032a490e0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeonBf16.S @@ -0,0 +1,779 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + + SconvKernelNeonBf16.S + +Abstract: + + This module implements the direct NCHW convolution hot path for AArch64 + using BF16 BFMMLA instructions with FP32 accumulation. + +--*/ + +#include "asmmacro.h" + + .text + .arch_extension bf16 + +// MLAS_NCHW_BF16_MMLA_PARAMS offsets (validated in C++ with static_assert). + .equ MLAS_NCHW_INPUT, 0 + .equ MLAS_NCHW_PACKED_FILTER, 8 + .equ MLAS_NCHW_OUTPUT, 16 + .equ MLAS_NCHW_BIAS, 24 + .equ MLAS_NCHW_STRIDE_W, 32 + .equ MLAS_NCHW_DILATED_W, 40 + .equ MLAS_NCHW_OUTPUT_COUNT, 48 + .equ MLAS_NCHW_FILTER_COUNT, 56 + .equ MLAS_NCHW_OUTPUT_STRIDE, 64 + .equ MLAS_NCHW_KERNEL_FLAGS, 72 + +// Packed filter layout constants. + .equ MLAS_BF16_GROUP_STRIDE_BYTES, 128 + .equ MLAS_BF16_GROUP2_OFFSET_BYTES, 256 + .equ MLAS_FILTER_CHANNEL_PAIR_BYTES, 16 + .equ MLAS_OUTPUT_CHANNEL_PAIR_BYTES, 8 + .equ MLAS_FILTER_CHANNEL_PAIR2_BYTES, 32 + .equ MLAS_OUTPUT_CHANNEL_PAIR2_BYTES, 16 + .equ MLAS_BF16_GROUP_STRIDE_PAIR1_BYTES, 144 + .equ MLAS_BF16_GROUP2_PAIR1_BYTES, 272 + +// Output layout constants. + .equ MLAS_OUTPUT_BLOCK_BYTES, 64 + .equ MLAS_OUTPUT_PAIR_BYTES, 128 + +// Packed filter stride for one filter set (192 bf16 values). + .equ MLAS_BF16_FILTER_STRIDE_BYTES, 384 + .equ MLAS_BF16_FILTER_STRIDE_ADJUST_BYTES, 256 + +// Convolution kernel flags. 
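+// These values presumably mirror the MLAS_CONV_KERNEL_FLAG_* bits used by
+// the C++ callers (accumulate = bit 0, bias = bit 1, ReLU = bit 2); the tbz
+// tests in the accumulation path below index those bit positions directly.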
+ .equ MLAS_CONV_FLAG_ACCUMULATE, 1 + .equ MLAS_CONV_FLAG_BIAS, 2 + .equ MLAS_CONV_FLAG_RELU, 4 + +// +// void MLASCALL +// MlasConvNchwBf16KernelNeonAsm(const MLAS_NCHW_BF16_MMLA_PARAMS* Params) +// +// Registers: +// x1 - input pointer (current output pair) +// x2 - packed filter base pointer +// x3 - output base pointer for filter block 0 (current output pair) +// x4 - bias base pointer for filter block 0 +// x5 - stride width in bytes +// x6 - dilated input width in bytes (row stride) +// x7 - output pair count +// x9 - output stride in bytes between filter blocks +// x10 - filter count +// x15 - two-output stride in bytes +// + +FUNCTION_ENTRY MlasConvNchwBf16KernelNeonAsm + ldr x1, [x0, #MLAS_NCHW_INPUT] + ldr x2, [x0, #MLAS_NCHW_PACKED_FILTER] + ldr x3, [x0, #MLAS_NCHW_OUTPUT] + ldr x4, [x0, #MLAS_NCHW_BIAS] + ldr x5, [x0, #MLAS_NCHW_STRIDE_W] + ldr x6, [x0, #MLAS_NCHW_DILATED_W] + ldr x7, [x0, #MLAS_NCHW_OUTPUT_COUNT] + ldr x10, [x0, #MLAS_NCHW_FILTER_COUNT] + ldr x9, [x0, #MLAS_NCHW_OUTPUT_STRIDE] + ldr w8, [x0, #MLAS_NCHW_KERNEL_FLAGS] + + cbz x7, 199f + cbz x10, 199f + + lsl x5, x5, #2 // stride width bytes + lsl x6, x6, #2 // row stride bytes + lsl x9, x9, #2 // output stride bytes + add x15, x5, x5 // two-output stride bytes + + lsr x7, x7, #1 // number of output pairs + cbz x7, 199f + + movi v31.16b, #0 // zero vector + + // If bias flag is set but the bias pointer is null, ignore the bias flag. + tst w8, #MLAS_CONV_FLAG_BIAS + b.eq 10f + cbnz x4, 10f + bic w8, w8, #MLAS_CONV_FLAG_BIAS +10: + // Dispatch to the most common flag combinations first. + tst w8, #(MLAS_CONV_FLAG_ACCUMULATE | MLAS_CONV_FLAG_BIAS | MLAS_CONV_FLAG_RELU) + b.eq 100f + + tst w8, #MLAS_CONV_FLAG_ACCUMULATE + b.ne 180f + + tst w8, #MLAS_CONV_FLAG_BIAS + b.ne 20f + + tst w8, #MLAS_CONV_FLAG_RELU + b.ne 160f + b 100f +20: + tst w8, #MLAS_CONV_FLAG_RELU + b.ne 140f + b 120f + +// +// Fast path: no accumulation, no bias, no activation. +// +100: +101: + // Compute the three BF16 input groups for two output positions. + mov x11, x1 // row0, output0 + add x12, x1, x6 // row1, output0 + add x13, x12, x6 // row2, output0 + add x14, x1, x5 // row0, output1 + add x0, x12, x5 // row1, output1 + + ld1 {v0.4s}, [x11] + ld1 {v1.4s}, [x14] + ld1 {v2.4s}, [x12] + ld1 {v3.4s}, [x0] + ld1 {v4.4s}, [x13] + + add x11, x13, x5 // row2, output1 + ld1 {v5.4s}, [x11] + + // Group 0: row0[0..2], row1[0] + ins v0.s[3], v2.s[0] + ins v1.s[3], v3.s[0] + bfcvtn v24.4h, v0.4s + bfcvtn2 v24.8h, v1.4s + + // Group 1: row1[1..2], row2[0..1] + ext v6.16b, v4.16b, v4.16b, #8 + ext v7.16b, v5.16b, v5.16b, #8 + ins v6.s[0], v2.s[1] + ins v6.s[1], v2.s[2] + ins v7.s[0], v3.s[1] + ins v7.s[1], v3.s[2] + bfcvtn v25.4h, v6.4s + bfcvtn2 v25.8h, v7.4s + + // Group 2: row2[2], zeros + mov v27.16b, v31.16b + mov v28.16b, v31.16b + ins v27.s[0], v4.s[2] + ins v28.s[0], v5.s[2] + bfcvtn v26.4h, v27.4s + bfcvtn2 v26.8h, v28.4s + + // Initialize per-filter pointers. + mov x12, x2 // packed filter base + mov x13, x3 // output base (filter 0) + mov x14, x4 // bias base (unused here) + mov x11, x10 // filter count + +102: + add x16, x13, #MLAS_OUTPUT_BLOCK_BYTES + mov x17, x13 // output0/output1 pointers + mov x0, #4 // 8 channel pairs, unrolled by 2 + +103: + movi v16.4s, #0 + movi v17.4s, #0 + + // Load two channel pairs of packed filter weights (three groups each). 
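+    // Each BFMMLA below reads its 8h operands as row-major 2x4 BF16 matrices
+    // and accumulates the 2x2 FP32 product Vd += Vn * Vm^T, so one q
+    // accumulator carries two output channels for both output positions; the
+    // uzp1/uzp2 sequence afterwards separates the two positions.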
+ ldp q1, q4, [x12, #MLAS_BF16_GROUP_STRIDE_BYTES] + ldp q2, q5, [x12, #MLAS_BF16_GROUP2_OFFSET_BYTES] + ldp q0, q3, [x12], #MLAS_FILTER_CHANNEL_PAIR2_BYTES + + // Accumulate two independent channel pairs to hide BFMMLA latency. + bfmmla v16.4s, v0.8h, v24.8h + bfmmla v16.4s, v1.8h, v25.8h + bfmmla v17.4s, v3.8h, v24.8h + bfmmla v16.4s, v2.8h, v26.8h + bfmmla v17.4s, v4.8h, v25.8h + bfmmla v17.4s, v5.8h, v26.8h + + // De-interleave channels for each output position and store 4 floats. + uzp1 v18.4s, v16.4s, v16.4s + uzp1 v19.4s, v17.4s, v17.4s + uzp2 v20.4s, v16.4s, v16.4s + uzp2 v21.4s, v17.4s, v17.4s + + st1 {v18.2s, v19.2s}, [x17], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + st1 {v20.2s, v21.2s}, [x16], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + + subs x0, x0, #1 + b.ne 103b + + add x13, x13, x9 + add x12, x12, #MLAS_BF16_FILTER_STRIDE_ADJUST_BYTES + subs x11, x11, #1 + b.ne 102b + + add x3, x3, #MLAS_OUTPUT_PAIR_BYTES + add x1, x1, x15 + + subs x7, x7, #1 + b.ne 101b + + b 199f + +// +// Bias only: no accumulation, no activation. +// +120: +121: + // Compute the three BF16 input groups for two output positions. + mov x11, x1 // row0, output0 + add x12, x1, x6 // row1, output0 + add x13, x12, x6 // row2, output0 + add x14, x1, x5 // row0, output1 + add x0, x12, x5 // row1, output1 + + ld1 {v0.4s}, [x11] + ld1 {v1.4s}, [x14] + ld1 {v2.4s}, [x12] + ld1 {v3.4s}, [x0] + ld1 {v4.4s}, [x13] + + add x11, x13, x5 // row2, output1 + ld1 {v5.4s}, [x11] + + // Group 0: row0[0..2], row1[0] + ins v0.s[3], v2.s[0] + ins v1.s[3], v3.s[0] + bfcvtn v24.4h, v0.4s + bfcvtn2 v24.8h, v1.4s + + // Group 1: row1[1..2], row2[0..1] + ext v6.16b, v4.16b, v4.16b, #8 + ext v7.16b, v5.16b, v5.16b, #8 + ins v6.s[0], v2.s[1] + ins v6.s[1], v2.s[2] + ins v7.s[0], v3.s[1] + ins v7.s[1], v3.s[2] + bfcvtn v25.4h, v6.4s + bfcvtn2 v25.8h, v7.4s + + // Group 2: row2[2], zeros + mov v27.16b, v31.16b + mov v28.16b, v31.16b + ins v27.s[0], v4.s[2] + ins v28.s[0], v5.s[2] + bfcvtn v26.4h, v27.4s + bfcvtn2 v26.8h, v28.4s + + // Initialize per-filter pointers. + mov x12, x2 // packed filter base + mov x13, x3 // output base (filter 0) + mov x14, x4 // bias base + mov x11, x10 // filter count + +122: + add x16, x13, #MLAS_OUTPUT_BLOCK_BYTES + mov x17, x13 + mov x0, #4 + +123: + // Bias is shared across the two output positions. + ldr q27, [x14], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + zip1 v16.4s, v27.4s, v27.4s + zip2 v17.4s, v27.4s, v27.4s + + // Load two channel pairs of packed filter weights (three groups each). + ldp q1, q4, [x12, #MLAS_BF16_GROUP_STRIDE_BYTES] + ldp q2, q5, [x12, #MLAS_BF16_GROUP2_OFFSET_BYTES] + ldp q0, q3, [x12], #MLAS_FILTER_CHANNEL_PAIR2_BYTES + + bfmmla v16.4s, v0.8h, v24.8h + bfmmla v16.4s, v1.8h, v25.8h + bfmmla v17.4s, v3.8h, v24.8h + bfmmla v16.4s, v2.8h, v26.8h + bfmmla v17.4s, v4.8h, v25.8h + bfmmla v17.4s, v5.8h, v26.8h + + uzp1 v18.4s, v16.4s, v16.4s + uzp1 v19.4s, v17.4s, v17.4s + uzp2 v20.4s, v16.4s, v16.4s + uzp2 v21.4s, v17.4s, v17.4s + + st1 {v18.2s, v19.2s}, [x17], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + st1 {v20.2s, v21.2s}, [x16], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + + subs x0, x0, #1 + b.ne 123b + + add x13, x13, x9 + add x12, x12, #MLAS_BF16_FILTER_STRIDE_ADJUST_BYTES + subs x11, x11, #1 + b.ne 122b + + add x3, x3, #MLAS_OUTPUT_PAIR_BYTES + add x1, x1, x15 + + subs x7, x7, #1 + b.ne 121b + + b 199f + +// +// Bias + ReLU: no accumulation. +// +140: +141: + // Compute the three BF16 input groups for two output positions. 
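+    // The nine 3x3 taps split into BFMMLA k-groups of four: group 0 holds
+    // taps 0-3, group 1 holds taps 4-7, and group 2 holds tap 8 padded with
+    // zeros, built once per output pair.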
+ mov x11, x1 // row0, output0 + add x12, x1, x6 // row1, output0 + add x13, x12, x6 // row2, output0 + add x14, x1, x5 // row0, output1 + add x0, x12, x5 // row1, output1 + + ld1 {v0.4s}, [x11] + ld1 {v1.4s}, [x14] + ld1 {v2.4s}, [x12] + ld1 {v3.4s}, [x0] + ld1 {v4.4s}, [x13] + + add x11, x13, x5 // row2, output1 + ld1 {v5.4s}, [x11] + + // Group 0: row0[0..2], row1[0] + ins v0.s[3], v2.s[0] + ins v1.s[3], v3.s[0] + bfcvtn v24.4h, v0.4s + bfcvtn2 v24.8h, v1.4s + + // Group 1: row1[1..2], row2[0..1] + ext v6.16b, v4.16b, v4.16b, #8 + ext v7.16b, v5.16b, v5.16b, #8 + ins v6.s[0], v2.s[1] + ins v6.s[1], v2.s[2] + ins v7.s[0], v3.s[1] + ins v7.s[1], v3.s[2] + bfcvtn v25.4h, v6.4s + bfcvtn2 v25.8h, v7.4s + + // Group 2: row2[2], zeros + mov v27.16b, v31.16b + mov v28.16b, v31.16b + ins v27.s[0], v4.s[2] + ins v28.s[0], v5.s[2] + bfcvtn v26.4h, v27.4s + bfcvtn2 v26.8h, v28.4s + + // Initialize per-filter pointers. + mov x12, x2 // packed filter base + mov x13, x3 // output base (filter 0) + mov x14, x4 // bias base + mov x11, x10 // filter count + +142: + add x16, x13, #MLAS_OUTPUT_BLOCK_BYTES + mov x17, x13 + mov x0, #4 + +143: + ldr q27, [x14], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + zip1 v16.4s, v27.4s, v27.4s + zip2 v17.4s, v27.4s, v27.4s + + // Load two channel pairs of packed filter weights (three groups each). + ldp q1, q4, [x12, #MLAS_BF16_GROUP_STRIDE_BYTES] + ldp q2, q5, [x12, #MLAS_BF16_GROUP2_OFFSET_BYTES] + ldp q0, q3, [x12], #MLAS_FILTER_CHANNEL_PAIR2_BYTES + + bfmmla v16.4s, v0.8h, v24.8h + bfmmla v16.4s, v1.8h, v25.8h + bfmmla v17.4s, v3.8h, v24.8h + bfmmla v16.4s, v2.8h, v26.8h + bfmmla v17.4s, v4.8h, v25.8h + bfmmla v17.4s, v5.8h, v26.8h + + fmax v16.4s, v16.4s, v31.4s + fmax v17.4s, v17.4s, v31.4s + + uzp1 v18.4s, v16.4s, v16.4s + uzp1 v19.4s, v17.4s, v17.4s + uzp2 v20.4s, v16.4s, v16.4s + uzp2 v21.4s, v17.4s, v17.4s + + st1 {v18.2s, v19.2s}, [x17], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + st1 {v20.2s, v21.2s}, [x16], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + + subs x0, x0, #1 + b.ne 143b + + add x13, x13, x9 + add x12, x12, #MLAS_BF16_FILTER_STRIDE_ADJUST_BYTES + subs x11, x11, #1 + b.ne 142b + + add x3, x3, #MLAS_OUTPUT_PAIR_BYTES + add x1, x1, x15 + + subs x7, x7, #1 + b.ne 141b + + b 199f + +// +// ReLU only: no accumulation, no bias. +// +160: +161: + // Compute the three BF16 input groups for two output positions. + mov x11, x1 // row0, output0 + add x12, x1, x6 // row1, output0 + add x13, x12, x6 // row2, output0 + add x14, x1, x5 // row0, output1 + add x0, x12, x5 // row1, output1 + + ld1 {v0.4s}, [x11] + ld1 {v1.4s}, [x14] + ld1 {v2.4s}, [x12] + ld1 {v3.4s}, [x0] + ld1 {v4.4s}, [x13] + + add x11, x13, x5 // row2, output1 + ld1 {v5.4s}, [x11] + + // Group 0: row0[0..2], row1[0] + ins v0.s[3], v2.s[0] + ins v1.s[3], v3.s[0] + bfcvtn v24.4h, v0.4s + bfcvtn2 v24.8h, v1.4s + + // Group 1: row1[1..2], row2[0..1] + ext v6.16b, v4.16b, v4.16b, #8 + ext v7.16b, v5.16b, v5.16b, #8 + ins v6.s[0], v2.s[1] + ins v6.s[1], v2.s[2] + ins v7.s[0], v3.s[1] + ins v7.s[1], v3.s[2] + bfcvtn v25.4h, v6.4s + bfcvtn2 v25.8h, v7.4s + + // Group 2: row2[2], zeros + mov v27.16b, v31.16b + mov v28.16b, v31.16b + ins v27.s[0], v4.s[2] + ins v28.s[0], v5.s[2] + bfcvtn v26.4h, v27.4s + bfcvtn2 v26.8h, v28.4s + + // Initialize per-filter pointers. 
+ mov x12, x2 // packed filter base + mov x13, x3 // output base (filter 0) + mov x14, x4 // bias base (unused here) + mov x11, x10 // filter count + +162: + add x16, x13, #MLAS_OUTPUT_BLOCK_BYTES + mov x17, x13 + mov x0, #4 + +163: + movi v16.4s, #0 + movi v17.4s, #0 + + // Load two channel pairs of packed filter weights (three groups each). + ldp q1, q4, [x12, #MLAS_BF16_GROUP_STRIDE_BYTES] + ldp q2, q5, [x12, #MLAS_BF16_GROUP2_OFFSET_BYTES] + ldp q0, q3, [x12], #MLAS_FILTER_CHANNEL_PAIR2_BYTES + + bfmmla v16.4s, v0.8h, v24.8h + bfmmla v16.4s, v1.8h, v25.8h + bfmmla v17.4s, v3.8h, v24.8h + bfmmla v16.4s, v2.8h, v26.8h + bfmmla v17.4s, v4.8h, v25.8h + bfmmla v17.4s, v5.8h, v26.8h + + fmax v16.4s, v16.4s, v31.4s + fmax v17.4s, v17.4s, v31.4s + + uzp1 v18.4s, v16.4s, v16.4s + uzp1 v19.4s, v17.4s, v17.4s + uzp2 v20.4s, v16.4s, v16.4s + uzp2 v21.4s, v17.4s, v17.4s + + st1 {v18.2s, v19.2s}, [x17], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + st1 {v20.2s, v21.2s}, [x16], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + + subs x0, x0, #1 + b.ne 163b + + add x13, x13, x9 + add x12, x12, #MLAS_BF16_FILTER_STRIDE_ADJUST_BYTES + subs x11, x11, #1 + b.ne 162b + + add x3, x3, #MLAS_OUTPUT_PAIR_BYTES + add x1, x1, x15 + + subs x7, x7, #1 + b.ne 161b + + b 199f + +// +// General accumulation path: accumulation and optional bias/ReLU. +// +180: +181: + // Compute the three BF16 input groups for two output positions. + mov x11, x1 // row0, output0 + add x12, x1, x6 // row1, output0 + add x13, x12, x6 // row2, output0 + add x14, x1, x5 // row0, output1 + add x0, x12, x5 // row1, output1 + + ld1 {v0.4s}, [x11] + ld1 {v1.4s}, [x14] + ld1 {v2.4s}, [x12] + ld1 {v3.4s}, [x0] + ld1 {v4.4s}, [x13] + + add x11, x13, x5 // row2, output1 + ld1 {v5.4s}, [x11] + + // Group 0: row0[0..2], row1[0] + ins v0.s[3], v2.s[0] + ins v1.s[3], v3.s[0] + bfcvtn v24.4h, v0.4s + bfcvtn2 v24.8h, v1.4s + + // Group 1: row1[1..2], row2[0..1] + ext v6.16b, v4.16b, v4.16b, #8 + ext v7.16b, v5.16b, v5.16b, #8 + ins v6.s[0], v2.s[1] + ins v6.s[1], v2.s[2] + ins v7.s[0], v3.s[1] + ins v7.s[1], v3.s[2] + bfcvtn v25.4h, v6.4s + bfcvtn2 v25.8h, v7.4s + + // Group 2: row2[2], zeros + mov v27.16b, v31.16b + mov v28.16b, v31.16b + ins v27.s[0], v4.s[2] + ins v28.s[0], v5.s[2] + bfcvtn v26.4h, v27.4s + bfcvtn2 v26.8h, v28.4s + + // Initialize per-filter pointers. + mov x12, x2 // packed filter base + mov x13, x3 // output base (filter 0) + mov x14, x4 // bias base + mov x11, x10 // filter count + +182: + add x16, x13, #MLAS_OUTPUT_BLOCK_BYTES + mov x17, x13 + mov x0, #4 + +183: + // Load existing output and interleave for BFMMLA accumulation. + ld1 {v18.4s}, [x17] + ld1 {v19.4s}, [x16] + zip1 v16.4s, v18.4s, v19.4s + zip2 v17.4s, v18.4s, v19.4s + + // Bias is shared across the two output positions. + tbz w8, #1, 184f + ldr q27, [x14], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + zip1 v28.4s, v27.4s, v27.4s + zip2 v29.4s, v27.4s, v27.4s + fadd v16.4s, v16.4s, v28.4s + fadd v17.4s, v17.4s, v29.4s +184: + + // Load two channel pairs of packed filter weights (three groups each). 
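+    // The packed filter advance matches the fast paths: four iterations of
+    // the #32 post-increment plus the #256 adjust below equal the full
+    // 384-byte per-filter stride.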
+ ldp q1, q4, [x12, #MLAS_BF16_GROUP_STRIDE_BYTES] + ldp q2, q5, [x12, #MLAS_BF16_GROUP2_OFFSET_BYTES] + ldp q0, q3, [x12], #MLAS_FILTER_CHANNEL_PAIR2_BYTES + + bfmmla v16.4s, v0.8h, v24.8h + bfmmla v16.4s, v1.8h, v25.8h + bfmmla v17.4s, v3.8h, v24.8h + bfmmla v16.4s, v2.8h, v26.8h + bfmmla v17.4s, v4.8h, v25.8h + bfmmla v17.4s, v5.8h, v26.8h + + tbz w8, #2, 185f + fmax v16.4s, v16.4s, v31.4s + fmax v17.4s, v17.4s, v31.4s +185: + + uzp1 v18.4s, v16.4s, v16.4s + uzp1 v19.4s, v17.4s, v17.4s + uzp2 v20.4s, v16.4s, v16.4s + uzp2 v21.4s, v17.4s, v17.4s + + st1 {v18.2s, v19.2s}, [x17], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + st1 {v20.2s, v21.2s}, [x16], #MLAS_OUTPUT_CHANNEL_PAIR2_BYTES + + subs x0, x0, #1 + b.ne 183b + + add x13, x13, x9 + add x12, x12, #MLAS_BF16_FILTER_STRIDE_ADJUST_BYTES + subs x11, x11, #1 + b.ne 182b + + add x3, x3, #MLAS_OUTPUT_PAIR_BYTES + add x1, x1, x15 + + subs x7, x7, #1 + b.ne 181b + +199: + ret + +// +// void MLASCALL +// MlasConvNchwBf16PackFilterNeonAsm( +// const float* Filter, // x0 +// size_t FilterStrideElements, // x1 +// size_t FilterCount, // x2 +// uint16_t* PackedFilter) // x3 +// +// Pack the direct 3x3 FP32 filter blocks into the BF16 layout consumed by the +// direct BFMMLA kernel above. +// + +FUNCTION_ENTRY MlasConvNchwBf16PackFilterNeonAsm + cbz x2, 399f + + lsl x1, x1, #2 // filter stride bytes + movi v31.16b, #0 + +300: // FilterCount loop + mov x4, x0 // current filter block + channel offset + mov x5, x3 // packed group 0 + add x6, x3, #MLAS_BF16_GROUP_STRIDE_BYTES + add x7, x3, #MLAS_BF16_GROUP2_OFFSET_BYTES + mov x8, #8 // channel pairs + +301: // Channel-pair loop + ldr d0, [x4] + ldr d1, [x4, #64] + ldr d2, [x4, #128] + ldr d3, [x4, #192] + zip1 v16.4s, v0.4s, v1.4s + zip1 v17.4s, v2.4s, v3.4s + bfcvtn v18.4h, v16.4s + bfcvtn v19.4h, v17.4s + zip1 v20.4s, v18.4s, v19.4s + str q20, [x5], #16 + + ldr d0, [x4, #256] + ldr d1, [x4, #320] + ldr d2, [x4, #384] + ldr d3, [x4, #448] + zip1 v16.4s, v0.4s, v1.4s + zip1 v17.4s, v2.4s, v3.4s + bfcvtn v18.4h, v16.4s + bfcvtn v19.4h, v17.4s + zip1 v20.4s, v18.4s, v19.4s + str q20, [x6], #16 + + ldr d0, [x4, #512] + mov v16.16b, v31.16b + mov v17.16b, v31.16b + ins v16.s[0], v0.s[0] + ins v17.s[0], v0.s[1] + bfcvtn v20.4h, v16.4s + bfcvtn2 v20.8h, v17.4s + str q20, [x7], #16 + + add x4, x4, #8 + subs x8, x8, #1 + b.ne 301b + + add x0, x0, x1 + add x3, x3, #MLAS_BF16_FILTER_STRIDE_BYTES + subs x2, x2, #1 + b.ne 300b + +399: + ret + +// +// void MLASCALL +// MlasConvBf16OutputPostProcessNeonAsm( +// float* Output, // x0 +// size_t OutputCount, // x1 +// const float* Bias, // x2 +// unsigned KernelFlags) // w3 +// +// Apply the shared BF16 epilogue across a span of 16-float output rows. 
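+//
+// Equivalent scalar behavior (an illustrative sketch, not a separate API):
+//
+//   for (size_t n = 0; n < OutputCount; n++) {
+//       for (size_t c = 0; c < 16; c++) {
+//           float v = Output[n * 16 + c];
+//           if (KernelFlags & MLAS_CONV_FLAG_BIAS) v += Bias[c];
+//           if (KernelFlags & MLAS_CONV_FLAG_RELU) v = fmaxf(v, 0.0f);
+//           Output[n * 16 + c] = v;
+//       }
+//   }
+//
+// The entry code below also clears MLAS_CONV_FLAG_BIAS when Bias is null.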
+// + +FUNCTION_ENTRY MlasConvBf16OutputPostProcessNeonAsm + cbz x1, 499f + + tst w3, #MLAS_CONV_FLAG_BIAS + b.eq 410f + cbnz x2, 410f + bic w3, w3, #MLAS_CONV_FLAG_BIAS +410: + tst w3, #(MLAS_CONV_FLAG_BIAS | MLAS_CONV_FLAG_RELU) + b.eq 499f + + movi v31.16b, #0 + + tst w3, #MLAS_CONV_FLAG_BIAS + b.eq 440f + + ldp q24, q25, [x2] + ldp q26, q27, [x2, #32] + + tst w3, #MLAS_CONV_FLAG_RELU + b.ne 460f + +420: // Bias only + ldp q0, q1, [x0] + ldp q2, q3, [x0, #32] + fadd v0.4s, v0.4s, v24.4s + fadd v1.4s, v1.4s, v25.4s + fadd v2.4s, v2.4s, v26.4s + fadd v3.4s, v3.4s, v27.4s + stp q0, q1, [x0] + stp q2, q3, [x0, #32] + add x0, x0, #MLAS_OUTPUT_BLOCK_BYTES + subs x1, x1, #1 + b.ne 420b + b 499f + +440: // ReLU only + ldp q0, q1, [x0] + ldp q2, q3, [x0, #32] + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + stp q0, q1, [x0] + stp q2, q3, [x0, #32] + add x0, x0, #MLAS_OUTPUT_BLOCK_BYTES + subs x1, x1, #1 + b.ne 440b + b 499f + +460: // Bias + ReLU + ldp q0, q1, [x0] + ldp q2, q3, [x0, #32] + fadd v0.4s, v0.4s, v24.4s + fadd v1.4s, v1.4s, v25.4s + fadd v2.4s, v2.4s, v26.4s + fadd v3.4s, v3.4s, v27.4s + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + stp q0, q1, [x0] + stp q2, q3, [x0, #32] + add x0, x0, #MLAS_OUTPUT_BLOCK_BYTES + subs x1, x1, #1 + b.ne 460b + +499: + ret diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeonBf16.S b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeonBf16.S new file mode 100644 index 0000000000000..1325c3f5fd3e1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeonBf16.S @@ -0,0 +1,1098 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + + +Module Name: + + SconvPointwiseKernelNeonBf16.S + +Abstract: + + This module implements the pointwise (1x1) NCHWc convolution on AArch64 + using BF16 matrix multiply accumulate instructions. + + The kernel consumes FP32 input and filter operands, but casts them to BF16 + before the multiply while accumulating in FP32 using BFMMLA. + +--*/ + +#include "asmmacro.h" + +#if defined(__aarch64__) + + .text + .arch_extension bf16 + +// +// void MlasConvPointwiseBf16PackFilterNeonAsm( +// const float* Filter, // x0 +// size_t InputChannels, // x1 +// uint16_t* PackedFilter) // x2 +// +// Pack one 16x16 FP32 filter tile per input-channel block into the BFMMLA- +// friendly BF16 layout consumed by the pointwise kernels below. 
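+//
+// Sizes per input-channel block: a 16x16 FP32 tile in (1024 bytes) becomes a
+// 16x16 BF16 tile out (512 bytes), processed as four k-groups of four input
+// channels each.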
+// + + FUNCTION_ENTRY MlasConvPointwiseBf16PackFilterNeonAsm + + cbz x1, 19f + +10: // InputChannels loop + mov x3, x0 // current 16x16 FP32 filter tile + mov x4, x2 // current packed BF16 tile + mov x5, #4 // four k-groups per tile + +11: // K-group loop + ldp q24, q25, [x3, #0] + ldp q26, q27, [x3, #32] + bfcvtn v0.4h, v24.4s + bfcvtn2 v0.8h, v25.4s + bfcvtn v1.4h, v26.4s + bfcvtn2 v1.8h, v27.4s + + ldp q24, q25, [x3, #64] + ldp q26, q27, [x3, #96] + bfcvtn v2.4h, v24.4s + bfcvtn2 v2.8h, v25.4s + bfcvtn v3.4h, v26.4s + bfcvtn2 v3.8h, v27.4s + + ldp q24, q25, [x3, #128] + ldp q26, q27, [x3, #160] + bfcvtn v4.4h, v24.4s + bfcvtn2 v4.8h, v25.4s + bfcvtn v5.4h, v26.4s + bfcvtn2 v5.8h, v27.4s + + ldp q24, q25, [x3, #192] + ldp q26, q27, [x3, #224] + bfcvtn v6.4h, v24.4s + bfcvtn2 v6.8h, v25.4s + bfcvtn v7.4h, v26.4s + bfcvtn2 v7.8h, v27.4s + + zip1 v16.8h, v0.8h, v2.8h + zip2 v17.8h, v0.8h, v2.8h + zip1 v18.8h, v4.8h, v6.8h + zip2 v19.8h, v4.8h, v6.8h + zip1 v20.8h, v1.8h, v3.8h + zip2 v21.8h, v1.8h, v3.8h + zip1 v22.8h, v5.8h, v7.8h + zip2 v23.8h, v5.8h, v7.8h + + zip1 v24.4s, v16.4s, v18.4s + zip2 v25.4s, v16.4s, v18.4s + zip1 v26.4s, v17.4s, v19.4s + zip2 v27.4s, v17.4s, v19.4s + zip1 v28.4s, v20.4s, v22.4s + zip2 v29.4s, v20.4s, v22.4s + zip1 v30.4s, v21.4s, v23.4s + zip2 v31.4s, v21.4s, v23.4s + + stp q24, q25, [x4], #32 + stp q26, q27, [x4], #32 + stp q28, q29, [x4], #32 + stp q30, q31, [x4], #32 + + add x3, x3, #256 + subs x5, x5, #1 + b.ne 11b + + add x0, x0, #1024 + add x2, x2, #512 + subs x1, x1, #1 + b.ne 10b + +19: + ret + +// +// void MlasConvPointwiseBf16PackedInputKernelNeon4xAsm( +// const uint16_t* PackedInput, // x0 +// const uint16_t* PackedFilter, // x1 +// float* Output, // x2 +// size_t OutputQuartetCount, // x3 +// size_t InputChannels) // x4 +// +// PackedInput layout per output quartet: +// input-channel major. Each input channel block contributes eight 8xBF16 +// records: k-groups 0-3 for rows 0-1 followed by k-groups 0-3 for rows 2-3. +// This preserves the original 4-output accumulator shape while avoiding +// repeated FP32-to-BF16 conversion for every filter tile. 
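+//
+// Reminder on the BFMMLA output shape (useful for reading the store path
+// below): each accumulator holds a 2x2 FP32 tile laid out as lanes
+// [r0c0, r0c1, r1c0, r1c1], where the rows come from the two halves of the
+// packed input record and the columns from the two halves of the filter
+// record. The stores therefore gather row 0 with zip1 on 64-bit lanes and
+// row 1 with zip2 before writing the 16-float output rows.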
+// +// Register ownership: +// x5 - remaining output quartets +// x6 - packed input base for current output quartet +// x7 - output base for current output quartet +// x10 - packed input pointer within ic loop / store base +// x11 - store pointer for row 1 +// x12 - packed filter pointer within ic loop +// x13 - remaining input channels +// x14/x15 - store pointers for rows 2-3 +// +// v16-v23 - accumulators for output rows 0-1 +// v24-v31 - accumulators for output rows 2-3 +// v8-v11 - staged packed BF16 inputs for rows 0-1 +// v12-v15 - staged packed BF16 inputs for rows 2-3 +// v0-v7 - filter loads and store staging +// + + FUNCTION_ENTRY MlasConvPointwiseBf16PackedInputKernelNeon4xAsm + + cbz x3, 139f + + sub sp, sp, #128 + stp q8, q9, [sp, #0] + stp q10, q11, [sp, #32] + stp q12, q13, [sp, #64] + stp q14, q15, [sp, #96] + + mov x5, x3 // remaining output quartets + mov x6, x0 // packed input base for current output quartet + mov x7, x2 // output base for current output quartet + +120: // OutputQuartetCount loop + eor v16.16b, v16.16b, v16.16b + eor v17.16b, v17.16b, v17.16b + eor v18.16b, v18.16b, v18.16b + eor v19.16b, v19.16b, v19.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + eor v24.16b, v24.16b, v24.16b + eor v25.16b, v25.16b, v25.16b + eor v26.16b, v26.16b, v26.16b + eor v27.16b, v27.16b, v27.16b + eor v28.16b, v28.16b, v28.16b + eor v29.16b, v29.16b, v29.16b + eor v30.16b, v30.16b, v30.16b + eor v31.16b, v31.16b, v31.16b + + mov x10, x6 // packed input pointer + mov x12, x1 // packed filter pointer + mov x13, x4 // remaining input channels + +121: // InputChannels loop + prfm pldl1keep, [x10, #256] + prfm pldl1keep, [x12, #512] + + ldp q8, q9, [x10], #32 + ldp q10, q11, [x10], #32 + ldp q12, q13, [x10], #32 + ldp q14, q15, [x10], #32 + + ldp q0, q1, [x12, #0] + ldp q2, q3, [x12, #32] + ldp q4, q5, [x12, #64] + ldp q6, q7, [x12, #96] + + bfmmla v16.4s, v8.8h, v0.8h + bfmmla v17.4s, v8.8h, v1.8h + bfmmla v18.4s, v8.8h, v2.8h + bfmmla v19.4s, v8.8h, v3.8h + bfmmla v24.4s, v12.8h, v0.8h + bfmmla v25.4s, v12.8h, v1.8h + bfmmla v26.4s, v12.8h, v2.8h + bfmmla v27.4s, v12.8h, v3.8h + + ldp q0, q1, [x12, #128] + ldp q2, q3, [x12, #160] + + bfmmla v20.4s, v8.8h, v4.8h + bfmmla v21.4s, v8.8h, v5.8h + bfmmla v22.4s, v8.8h, v6.8h + bfmmla v23.4s, v8.8h, v7.8h + bfmmla v28.4s, v12.8h, v4.8h + bfmmla v29.4s, v12.8h, v5.8h + bfmmla v30.4s, v12.8h, v6.8h + bfmmla v31.4s, v12.8h, v7.8h + + ldp q4, q5, [x12, #192] + ldp q6, q7, [x12, #224] + + bfmmla v16.4s, v9.8h, v0.8h + bfmmla v17.4s, v9.8h, v1.8h + bfmmla v18.4s, v9.8h, v2.8h + bfmmla v19.4s, v9.8h, v3.8h + bfmmla v24.4s, v13.8h, v0.8h + bfmmla v25.4s, v13.8h, v1.8h + bfmmla v26.4s, v13.8h, v2.8h + bfmmla v27.4s, v13.8h, v3.8h + + ldp q0, q1, [x12, #256] + ldp q2, q3, [x12, #288] + + bfmmla v20.4s, v9.8h, v4.8h + bfmmla v21.4s, v9.8h, v5.8h + bfmmla v22.4s, v9.8h, v6.8h + bfmmla v23.4s, v9.8h, v7.8h + bfmmla v28.4s, v13.8h, v4.8h + bfmmla v29.4s, v13.8h, v5.8h + bfmmla v30.4s, v13.8h, v6.8h + bfmmla v31.4s, v13.8h, v7.8h + + ldp q4, q5, [x12, #320] + ldp q6, q7, [x12, #352] + + bfmmla v16.4s, v10.8h, v0.8h + bfmmla v17.4s, v10.8h, v1.8h + bfmmla v18.4s, v10.8h, v2.8h + bfmmla v19.4s, v10.8h, v3.8h + bfmmla v24.4s, v14.8h, v0.8h + bfmmla v25.4s, v14.8h, v1.8h + bfmmla v26.4s, v14.8h, v2.8h + bfmmla v27.4s, v14.8h, v3.8h + + ldp q0, q1, [x12, #384] + ldp q2, q3, [x12, #416] + + bfmmla v20.4s, v10.8h, v4.8h + bfmmla v21.4s, v10.8h, v5.8h + bfmmla v22.4s, 
v10.8h, v6.8h + bfmmla v23.4s, v10.8h, v7.8h + bfmmla v28.4s, v14.8h, v4.8h + bfmmla v29.4s, v14.8h, v5.8h + bfmmla v30.4s, v14.8h, v6.8h + bfmmla v31.4s, v14.8h, v7.8h + + ldp q4, q5, [x12, #448] + ldp q6, q7, [x12, #480] + + bfmmla v16.4s, v11.8h, v0.8h + bfmmla v17.4s, v11.8h, v1.8h + bfmmla v18.4s, v11.8h, v2.8h + bfmmla v19.4s, v11.8h, v3.8h + bfmmla v24.4s, v15.8h, v0.8h + bfmmla v25.4s, v15.8h, v1.8h + bfmmla v26.4s, v15.8h, v2.8h + bfmmla v27.4s, v15.8h, v3.8h + bfmmla v20.4s, v11.8h, v4.8h + bfmmla v21.4s, v11.8h, v5.8h + bfmmla v22.4s, v11.8h, v6.8h + bfmmla v23.4s, v11.8h, v7.8h + bfmmla v28.4s, v15.8h, v4.8h + bfmmla v29.4s, v15.8h, v5.8h + bfmmla v30.4s, v15.8h, v6.8h + bfmmla v31.4s, v15.8h, v7.8h + + add x12, x12, #512 + subs x13, x13, #1 + b.ne 121b + + mov x6, x10 + mov x10, x7 + add x11, x10, #64 + + zip1 v0.2d, v16.2d, v17.2d + zip1 v1.2d, v18.2d, v19.2d + zip1 v2.2d, v20.2d, v21.2d + zip1 v3.2d, v22.2d, v23.2d + zip2 v4.2d, v16.2d, v17.2d + zip2 v5.2d, v18.2d, v19.2d + zip2 v6.2d, v20.2d, v21.2d + zip2 v7.2d, v22.2d, v23.2d + + stp q0, q1, [x10], #32 + stp q2, q3, [x10], #32 + stp q4, q5, [x11], #32 + stp q6, q7, [x11], #32 + + add x14, x7, #128 + add x15, x7, #192 + + zip1 v0.2d, v24.2d, v25.2d + zip1 v1.2d, v26.2d, v27.2d + zip1 v2.2d, v28.2d, v29.2d + zip1 v3.2d, v30.2d, v31.2d + zip2 v4.2d, v24.2d, v25.2d + zip2 v5.2d, v26.2d, v27.2d + zip2 v6.2d, v28.2d, v29.2d + zip2 v7.2d, v30.2d, v31.2d + + stp q0, q1, [x14], #32 + stp q2, q3, [x14], #32 + stp q4, q5, [x15], #32 + stp q6, q7, [x15], #32 + + add x7, x7, #256 + subs x5, x5, #1 + b.ne 120b + + ldp q8, q9, [sp, #0] + ldp q10, q11, [sp, #32] + ldp q12, q13, [sp, #64] + ldp q14, q15, [sp, #96] + add sp, sp, #128 +139: + ret + +// +// void MlasConvPointwiseBf16PackedInputKernelNeon2xAsm( +// const uint16_t* PackedInput, // x0 +// const uint16_t* PackedFilter, // x1 +// float* Output, // x2 +// size_t OutputPairCount, // x3 +// size_t InputChannels) // x4 +// +// PackedInput layout per output pair: +// input-channel major. Each input channel block contributes four 8xBF16 +// records (k-groups 0-3). Each record stores the low row in lanes 0-3 and +// the high row in lanes 4-7 so that the existing two-output BFMMLA flow can +// be reused without repeating FP32-to-BF16 conversion for every filter tile. 
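+//
+// Concretely, for one output pair and k-group g in 0..3, each record is
+// [bf16(In0[4g+0..3]), bf16(In1[4g+0..3])], where In0/In1 are the 16-float
+// input blocks of the two outputs (packed by MlasPackPointwiseBf16InputPairNeon
+// in sbconv_kernel_neon.cpp).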
+// +// Register ownership: +// x5 - remaining output pairs +// x6 - packed input base for current output pair +// x7 - output base for current output pair +// x10 - packed input pointer within ic loop +// x11 - packed filter pointer within ic loop +// x12 - remaining input channels +// x14/x15 - store pointers +// +// v16-v23 - accumulators for the 8 column pairs +// v8-v11 - staged packed BF16 inputs for k-groups 0-3 +// v0-v7 - filter loads and store staging +// + + FUNCTION_ENTRY MlasConvPointwiseBf16PackedInputKernelNeon2xAsm + + cbz x3, 159f + + sub sp, sp, #64 + stp q8, q9, [sp, #0] + stp q10, q11, [sp, #32] + + mov x5, x3 // remaining output pairs + mov x6, x0 // packed input base for current output pair + mov x7, x2 // output base for current output pair + +140: // OutputPairCount loop + eor v16.16b, v16.16b, v16.16b + eor v17.16b, v17.16b, v17.16b + eor v18.16b, v18.16b, v18.16b + eor v19.16b, v19.16b, v19.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + + mov x10, x6 // packed input pointer + mov x11, x1 // packed filter pointer + mov x12, x4 // remaining input channels + +141: // InputChannels loop + prfm pldl1keep, [x10, #256] + prfm pldl1keep, [x11, #512] + + ldp q8, q9, [x10], #32 + ldp q10, q11, [x10], #32 + + ldp q0, q1, [x11, #0] + ldp q2, q3, [x11, #32] + ldp q4, q5, [x11, #64] + ldp q6, q7, [x11, #96] + + bfmmla v16.4s, v8.8h, v0.8h + bfmmla v17.4s, v8.8h, v1.8h + bfmmla v18.4s, v8.8h, v2.8h + bfmmla v19.4s, v8.8h, v3.8h + + ldp q0, q1, [x11, #128] + ldp q2, q3, [x11, #160] + + bfmmla v20.4s, v8.8h, v4.8h + bfmmla v21.4s, v8.8h, v5.8h + bfmmla v22.4s, v8.8h, v6.8h + bfmmla v23.4s, v8.8h, v7.8h + + ldp q4, q5, [x11, #192] + ldp q6, q7, [x11, #224] + + bfmmla v16.4s, v9.8h, v0.8h + bfmmla v17.4s, v9.8h, v1.8h + bfmmla v18.4s, v9.8h, v2.8h + bfmmla v19.4s, v9.8h, v3.8h + + ldp q0, q1, [x11, #256] + ldp q2, q3, [x11, #288] + + bfmmla v20.4s, v9.8h, v4.8h + bfmmla v21.4s, v9.8h, v5.8h + bfmmla v22.4s, v9.8h, v6.8h + bfmmla v23.4s, v9.8h, v7.8h + + ldp q4, q5, [x11, #320] + ldp q6, q7, [x11, #352] + + bfmmla v16.4s, v10.8h, v0.8h + bfmmla v17.4s, v10.8h, v1.8h + bfmmla v18.4s, v10.8h, v2.8h + bfmmla v19.4s, v10.8h, v3.8h + + ldp q0, q1, [x11, #384] + ldp q2, q3, [x11, #416] + + bfmmla v20.4s, v10.8h, v4.8h + bfmmla v21.4s, v10.8h, v5.8h + bfmmla v22.4s, v10.8h, v6.8h + bfmmla v23.4s, v10.8h, v7.8h + + ldp q4, q5, [x11, #448] + ldp q6, q7, [x11, #480] + + bfmmla v16.4s, v11.8h, v0.8h + bfmmla v17.4s, v11.8h, v1.8h + bfmmla v18.4s, v11.8h, v2.8h + bfmmla v19.4s, v11.8h, v3.8h + bfmmla v20.4s, v11.8h, v4.8h + bfmmla v21.4s, v11.8h, v5.8h + bfmmla v22.4s, v11.8h, v6.8h + bfmmla v23.4s, v11.8h, v7.8h + + add x11, x11, #512 + subs x12, x12, #1 + b.ne 141b + + mov x6, x10 + mov x14, x7 + add x15, x14, #64 + + zip1 v0.2d, v16.2d, v17.2d + zip1 v1.2d, v18.2d, v19.2d + zip1 v2.2d, v20.2d, v21.2d + zip1 v3.2d, v22.2d, v23.2d + zip2 v4.2d, v16.2d, v17.2d + zip2 v5.2d, v18.2d, v19.2d + zip2 v6.2d, v20.2d, v21.2d + zip2 v7.2d, v22.2d, v23.2d + + stp q0, q1, [x14], #32 + stp q2, q3, [x14], #32 + stp q4, q5, [x15], #32 + stp q6, q7, [x15], #32 + + add x7, x7, #128 + subs x5, x5, #1 + b.ne 140b + + ldp q8, q9, [sp, #0] + ldp q10, q11, [sp, #32] + add sp, sp, #64 +159: + ret + +// +// void MlasConvPointwiseBf16KernelNeonAsm( +// const float* Input, // x0 +// const uint16_t* PackedFilter, // x1 +// float* Output, // x2 +// size_t StrideWidthBytes, // x3 +// size_t InputStrideBytes, // x4 +// size_t 
InputChannels, // x5 +// size_t OutputCountEven) // x6 +// +// PackedFilter layout per input channel block: +// k-group major (4 groups of 4 BF16 lanes) and, within each group, +// column-pair major (8 pairs of output columns). +// Each (k-group, column-pair) record is 8 BF16 values laid out as: +// [b(k+0,c0), b(k+1,c0), b(k+2,c0), b(k+3,c0), +// b(k+0,c1), b(k+1,c1), b(k+2,c1), b(k+3,c1)] +// +// Register ownership: +// x7 - remaining output count (even) +// x8 - input base pointer for current output index +// x9 - output base pointer for current output index +// x10 - input pointer (output row 0) within ic loop +// x11 - input pointer (output row 1) within ic loop +// x12 - packed filter pointer within ic loop +// x13 - remaining input channels +// x14 - input pointer (output row 2) within ic loop / scratch +// x15 - input pointer (output row 3) within ic loop / scratch +// x16 - StrideWidthBytes * 2 +// x17 - InputStrideBytes - (BlockSize * sizeof(float)) +// x6 - StrideWidthBytes * 4 +// +// v16-v23 - accumulators for output rows 0-1 +// v24-v31 - accumulators for output rows 2-3 +// v8-v11 - staged BF16 inputs for rows 0-1 (k-groups 0-3) +// v12-v15 - staged BF16 inputs for rows 2-3 (k-groups 0-3) +// v0-v3 - filter bank A / temporaries +// v4-v7 - filter bank B / temporaries and store staging +// + + FUNCTION_ENTRY MlasConvPointwiseBf16KernelNeonAsm + + cbz x6, 98f + + // The optimized hot path uses v8-v15 to stage all four k-groups of + // converted BF16 input for a channel block. Preserve the callee-saved + // registers once per call to keep the inner loop free of spills. + sub sp, sp, #128 + stp q8, q9, [sp, #0] + stp q10, q11, [sp, #32] + stp q12, q13, [sp, #64] + stp q14, q15, [sp, #96] + + mov x7, x6 // remaining outputs + mov x8, x0 // input base for current output index + mov x9, x2 // output base for current output index + + add x16, x3, x3 // StrideWidthBytes * 2 + sub x17, x4, #64 // InputStrideBytes - 64 + + cmp x7, #4 + b.lo 90f + + add x6, x16, x16 // StrideWidthBytes * 4 + +80: // OutputCountEven loop (process 4 outputs per iteration) + // Zero accumulators for rows 0-1. + eor v16.16b, v16.16b, v16.16b + eor v17.16b, v17.16b, v17.16b + eor v18.16b, v18.16b, v18.16b + eor v19.16b, v19.16b, v19.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + + // Zero accumulators for rows 2-3. + eor v24.16b, v24.16b, v24.16b + eor v25.16b, v25.16b, v25.16b + eor v26.16b, v26.16b, v26.16b + eor v27.16b, v27.16b, v27.16b + eor v28.16b, v28.16b, v28.16b + eor v29.16b, v29.16b, v29.16b + eor v30.16b, v30.16b, v30.16b + eor v31.16b, v31.16b, v31.16b + + // Initialize per-iteration pointers. + mov x10, x8 // input row 0 base + add x11, x10, x3 // input row 1 base + add x14, x10, x16 // input row 2 base + add x15, x11, x16 // input row 3 base + mov x12, x1 // packed filter base + mov x13, x5 // remaining input channel blocks + +81: // InputChannels loop (4 outputs) + // Prefetch the next input/filter panels to keep the load pipelines fed. + prfm pldl1keep, [x10, #256] + prfm pldl1keep, [x12, #512] + + // Convert all four k-groups up front into v8-v15. Use LDP to fetch two + // k-groups per row at a time, cutting the number of load instructions in + // half without changing the dataflow seen by the BFMMLA core. + // k-groups 0-1 input conversion (rows 0-1). 
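+    // (bfcvtn narrows row 0 into lanes 0-3 and clears the upper half; bfcvtn2
+    // then fills lanes 4-7 with row 1, forming the 2x4 BF16 operand that
+    // BFMMLA expects.)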
+ ldp q0, q1, [x10], #32 + ldp q2, q3, [x11], #32 + bfcvtn v8.4h, v0.4s + bfcvtn2 v8.8h, v2.4s + bfcvtn v9.4h, v1.4s + bfcvtn2 v9.8h, v3.4s + + // k-groups 2-3 input conversion (rows 0-1). + ldp q0, q1, [x10], #32 + ldp q2, q3, [x11], #32 + bfcvtn v10.4h, v0.4s + bfcvtn2 v10.8h, v2.4s + bfcvtn v11.4h, v1.4s + bfcvtn2 v11.8h, v3.4s + + // k-groups 0-1 input conversion (rows 2-3). + ldp q0, q1, [x14], #32 + ldp q2, q3, [x15], #32 + bfcvtn v12.4h, v0.4s + bfcvtn2 v12.8h, v2.4s + bfcvtn v13.4h, v1.4s + bfcvtn2 v13.8h, v3.4s + + // k-groups 2-3 input conversion (rows 2-3). + ldp q0, q1, [x14], #32 + ldp q2, q3, [x15], #32 + bfcvtn v14.4h, v0.4s + bfcvtn2 v14.8h, v2.4s + bfcvtn v15.4h, v1.4s + bfcvtn2 v15.8h, v3.4s + + // Advance to the next input channel block early; the pointers won't be + // reused until the next iteration. + add x10, x10, x17 + add x11, x11, x17 + add x14, x14, x17 + add x15, x15, x17 + + // Preload k-group 0 filter panels into two register banks. This enables + // loading the next k-group's column pairs 0-3 while column pairs 4-7 are + // still accumulating, hiding most of the load-to-use latency. + ldp q0, q1, [x12, #0] + ldp q2, q3, [x12, #32] + ldp q4, q5, [x12, #64] + ldp q6, q7, [x12, #96] + + // k-group 0, column pairs 0-3. + bfmmla v16.4s, v8.8h, v0.8h + bfmmla v17.4s, v8.8h, v1.8h + bfmmla v18.4s, v8.8h, v2.8h + bfmmla v19.4s, v8.8h, v3.8h + bfmmla v24.4s, v12.8h, v0.8h + bfmmla v25.4s, v12.8h, v1.8h + bfmmla v26.4s, v12.8h, v2.8h + bfmmla v27.4s, v12.8h, v3.8h + + // Preload k-group 1, column pairs 0-3. + ldp q0, q1, [x12, #128] + ldp q2, q3, [x12, #160] + + // k-group 0, column pairs 4-7. + bfmmla v20.4s, v8.8h, v4.8h + bfmmla v21.4s, v8.8h, v5.8h + bfmmla v22.4s, v8.8h, v6.8h + bfmmla v23.4s, v8.8h, v7.8h + bfmmla v28.4s, v12.8h, v4.8h + bfmmla v29.4s, v12.8h, v5.8h + bfmmla v30.4s, v12.8h, v6.8h + bfmmla v31.4s, v12.8h, v7.8h + + // Preload k-group 1, column pairs 4-7. + ldp q4, q5, [x12, #192] + ldp q6, q7, [x12, #224] + + // k-group 1, column pairs 0-3. + bfmmla v16.4s, v9.8h, v0.8h + bfmmla v17.4s, v9.8h, v1.8h + bfmmla v18.4s, v9.8h, v2.8h + bfmmla v19.4s, v9.8h, v3.8h + bfmmla v24.4s, v13.8h, v0.8h + bfmmla v25.4s, v13.8h, v1.8h + bfmmla v26.4s, v13.8h, v2.8h + bfmmla v27.4s, v13.8h, v3.8h + + // Preload k-group 2, column pairs 0-3. + ldp q0, q1, [x12, #256] + ldp q2, q3, [x12, #288] + + // k-group 1, column pairs 4-7. + bfmmla v20.4s, v9.8h, v4.8h + bfmmla v21.4s, v9.8h, v5.8h + bfmmla v22.4s, v9.8h, v6.8h + bfmmla v23.4s, v9.8h, v7.8h + bfmmla v28.4s, v13.8h, v4.8h + bfmmla v29.4s, v13.8h, v5.8h + bfmmla v30.4s, v13.8h, v6.8h + bfmmla v31.4s, v13.8h, v7.8h + + // Preload k-group 2, column pairs 4-7. + ldp q4, q5, [x12, #320] + ldp q6, q7, [x12, #352] + + // k-group 2, column pairs 0-3. + bfmmla v16.4s, v10.8h, v0.8h + bfmmla v17.4s, v10.8h, v1.8h + bfmmla v18.4s, v10.8h, v2.8h + bfmmla v19.4s, v10.8h, v3.8h + bfmmla v24.4s, v14.8h, v0.8h + bfmmla v25.4s, v14.8h, v1.8h + bfmmla v26.4s, v14.8h, v2.8h + bfmmla v27.4s, v14.8h, v3.8h + + // Preload k-group 3, column pairs 0-3. + ldp q0, q1, [x12, #384] + ldp q2, q3, [x12, #416] + + // k-group 2, column pairs 4-7. + bfmmla v20.4s, v10.8h, v4.8h + bfmmla v21.4s, v10.8h, v5.8h + bfmmla v22.4s, v10.8h, v6.8h + bfmmla v23.4s, v10.8h, v7.8h + bfmmla v28.4s, v14.8h, v4.8h + bfmmla v29.4s, v14.8h, v5.8h + bfmmla v30.4s, v14.8h, v6.8h + bfmmla v31.4s, v14.8h, v7.8h + + // Preload k-group 3, column pairs 4-7. + ldp q4, q5, [x12, #448] + ldp q6, q7, [x12, #480] + + // k-group 3, column pairs 0-3. 
+ bfmmla v16.4s, v11.8h, v0.8h + bfmmla v17.4s, v11.8h, v1.8h + bfmmla v18.4s, v11.8h, v2.8h + bfmmla v19.4s, v11.8h, v3.8h + bfmmla v24.4s, v15.8h, v0.8h + bfmmla v25.4s, v15.8h, v1.8h + bfmmla v26.4s, v15.8h, v2.8h + bfmmla v27.4s, v15.8h, v3.8h + + // k-group 3, column pairs 4-7. + bfmmla v20.4s, v11.8h, v4.8h + bfmmla v21.4s, v11.8h, v5.8h + bfmmla v22.4s, v11.8h, v6.8h + bfmmla v23.4s, v11.8h, v7.8h + bfmmla v28.4s, v15.8h, v4.8h + bfmmla v29.4s, v15.8h, v5.8h + bfmmla v30.4s, v15.8h, v6.8h + bfmmla v31.4s, v15.8h, v7.8h + + add x12, x12, #512 + + subs x13, x13, #1 + b.ne 81b + + // Store rows 0-1: de-interleave via 64-bit ZIP. + mov x10, x9 + add x11, x10, #64 + + zip1 v0.2d, v16.2d, v17.2d + zip1 v1.2d, v18.2d, v19.2d + zip1 v2.2d, v20.2d, v21.2d + zip1 v3.2d, v22.2d, v23.2d + zip2 v4.2d, v16.2d, v17.2d + zip2 v5.2d, v18.2d, v19.2d + zip2 v6.2d, v20.2d, v21.2d + zip2 v7.2d, v22.2d, v23.2d + + stp q0, q1, [x10], #32 + stp q2, q3, [x10], #32 + stp q4, q5, [x11], #32 + stp q6, q7, [x11], #32 + + // Store rows 2-3. + add x14, x9, #128 + add x15, x9, #192 + + zip1 v0.2d, v24.2d, v25.2d + zip1 v1.2d, v26.2d, v27.2d + zip1 v2.2d, v28.2d, v29.2d + zip1 v3.2d, v30.2d, v31.2d + zip2 v4.2d, v24.2d, v25.2d + zip2 v5.2d, v26.2d, v27.2d + zip2 v6.2d, v28.2d, v29.2d + zip2 v7.2d, v30.2d, v31.2d + + stp q0, q1, [x14], #32 + stp q2, q3, [x14], #32 + stp q4, q5, [x15], #32 + stp q6, q7, [x15], #32 + + // Advance to the next block of outputs. + subs x7, x7, #4 + add x8, x8, x6 + add x9, x9, #256 + cmp x7, #4 + b.hs 80b + +90: // OutputCountEven loop (process 2 outputs per iteration) + cbz x7, 99f + + // Zero accumulators. + eor v16.16b, v16.16b, v16.16b + eor v17.16b, v17.16b, v17.16b + eor v18.16b, v18.16b, v18.16b + eor v19.16b, v19.16b, v19.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + + // Initialize per-iteration pointers. + mov x10, x8 // input row 0 base + add x11, x10, x3 // input row 1 base + mov x12, x1 // packed filter base + mov x13, x5 // remaining input channel blocks + +91: // InputChannels loop (2 outputs) + // Prefetch the next panels. + prfm pldl1keep, [x10, #256] + prfm pldl1keep, [x12, #512] + + // Convert all four k-groups into v8-v11 using paired loads. + ldp q0, q1, [x10], #32 + ldp q2, q3, [x11], #32 + bfcvtn v8.4h, v0.4s + bfcvtn2 v8.8h, v2.4s + bfcvtn v9.4h, v1.4s + bfcvtn2 v9.8h, v3.4s + + ldp q0, q1, [x10], #32 + ldp q2, q3, [x11], #32 + bfcvtn v10.4h, v0.4s + bfcvtn2 v10.8h, v2.4s + bfcvtn v11.4h, v1.4s + bfcvtn2 v11.8h, v3.4s + + // Advance to the next input channel block early. + add x10, x10, x17 + add x11, x11, x17 + + // Preload k-group 0 filter panels into two banks. + ldp q0, q1, [x12, #0] + ldp q2, q3, [x12, #32] + ldp q4, q5, [x12, #64] + ldp q6, q7, [x12, #96] + + // k-group 0, column pairs 0-3. + bfmmla v16.4s, v8.8h, v0.8h + bfmmla v17.4s, v8.8h, v1.8h + bfmmla v18.4s, v8.8h, v2.8h + bfmmla v19.4s, v8.8h, v3.8h + + // Preload k-group 1, column pairs 0-3. + ldp q0, q1, [x12, #128] + ldp q2, q3, [x12, #160] + + // k-group 0, column pairs 4-7. + bfmmla v20.4s, v8.8h, v4.8h + bfmmla v21.4s, v8.8h, v5.8h + bfmmla v22.4s, v8.8h, v6.8h + bfmmla v23.4s, v8.8h, v7.8h + + // Preload k-group 1, column pairs 4-7. + ldp q4, q5, [x12, #192] + ldp q6, q7, [x12, #224] + + // k-group 1, column pairs 0-3. + bfmmla v16.4s, v9.8h, v0.8h + bfmmla v17.4s, v9.8h, v1.8h + bfmmla v18.4s, v9.8h, v2.8h + bfmmla v19.4s, v9.8h, v3.8h + + // Preload k-group 2, column pairs 0-3. 
+ ldp q0, q1, [x12, #256] + ldp q2, q3, [x12, #288] + + // k-group 1, column pairs 4-7. + bfmmla v20.4s, v9.8h, v4.8h + bfmmla v21.4s, v9.8h, v5.8h + bfmmla v22.4s, v9.8h, v6.8h + bfmmla v23.4s, v9.8h, v7.8h + + // Preload k-group 2, column pairs 4-7. + ldp q4, q5, [x12, #320] + ldp q6, q7, [x12, #352] + + // k-group 2, column pairs 0-3. + bfmmla v16.4s, v10.8h, v0.8h + bfmmla v17.4s, v10.8h, v1.8h + bfmmla v18.4s, v10.8h, v2.8h + bfmmla v19.4s, v10.8h, v3.8h + + // Preload k-group 3, column pairs 0-3. + ldp q0, q1, [x12, #384] + ldp q2, q3, [x12, #416] + + // k-group 2, column pairs 4-7. + bfmmla v20.4s, v10.8h, v4.8h + bfmmla v21.4s, v10.8h, v5.8h + bfmmla v22.4s, v10.8h, v6.8h + bfmmla v23.4s, v10.8h, v7.8h + + // Preload k-group 3, column pairs 4-7. + ldp q4, q5, [x12, #448] + ldp q6, q7, [x12, #480] + + // k-group 3, column pairs 0-3. + bfmmla v16.4s, v11.8h, v0.8h + bfmmla v17.4s, v11.8h, v1.8h + bfmmla v18.4s, v11.8h, v2.8h + bfmmla v19.4s, v11.8h, v3.8h + + // k-group 3, column pairs 4-7. + bfmmla v20.4s, v11.8h, v4.8h + bfmmla v21.4s, v11.8h, v5.8h + bfmmla v22.4s, v11.8h, v6.8h + bfmmla v23.4s, v11.8h, v7.8h + + add x12, x12, #512 + + // Advance to the next input channel block. + subs x13, x13, #1 + b.ne 91b + + // Store: de-interleave rows via 64-bit ZIP instructions. + mov x14, x9 + add x15, x14, #64 + + zip1 v0.2d, v16.2d, v17.2d + zip1 v1.2d, v18.2d, v19.2d + zip1 v2.2d, v20.2d, v21.2d + zip1 v3.2d, v22.2d, v23.2d + zip2 v4.2d, v16.2d, v17.2d + zip2 v5.2d, v18.2d, v19.2d + zip2 v6.2d, v20.2d, v21.2d + zip2 v7.2d, v22.2d, v23.2d + + stp q0, q1, [x14], #32 + stp q2, q3, [x14], #32 + stp q4, q5, [x15], #32 + stp q6, q7, [x15], #32 + + // Advance to the next pair of outputs. + subs x7, x7, #2 + add x8, x8, x16 + add x9, x9, #128 + b.ne 90b + +99: + ldp q8, q9, [sp, #0] + ldp q10, q11, [sp, #32] + ldp q12, q13, [sp, #64] + ldp q14, q15, [sp, #96] + add sp, sp, #128 +98: + ret + +// +// Single-output tail kernel. +// +// void MlasConvPointwiseBf16KernelNeonSingleOutputAsm( +// const float* Input, // x0 +// const uint16_t* PackedFilter, // x1 +// float* Output, // x2 +// size_t StrideWidthBytes, // x3 +// size_t InputStrideBytes, // x4 +// size_t InputChannels, // x5 +// size_t OutputCount) // x6 +// +// This kernel processes one output at a time and is used for odd-width +// remainders (or narrow outputs such as ow=1). It still uses BFMMLA by +// presenting the single output row as row 0 and a zero row as row 1. +// +// Register ownership: +// x7 - remaining output count +// x8 - input base pointer for current output index +// x9 - output base pointer for current output index +// x10 - input pointer within ic loop +// x11 - packed filter pointer within ic loop +// x12 - remaining input channels +// x17 - InputStrideBytes - (BlockSize * sizeof(float)) +// +// v16-v23 - accumulators for the 8 column pairs +// v0-v1 - FP32 input loads (two k-groups at a time) +// v2/v7 - BF16 inputs for the active k-group (high half kept zero) +// v3-v6 - filter loads and store staging +// + + FUNCTION_ENTRY MlasConvPointwiseBf16KernelNeonSingleOutputAsm + + cbz x6, 199f + + mov x7, x6 // remaining outputs + mov x8, x0 // input base for current output index + mov x9, x2 // output base for current output index + + sub x17, x4, #64 // InputStrideBytes - 64 + +180: // OutputCount loop (process 1 output per iteration) + // Zero accumulators. 
+ eor v16.16b, v16.16b, v16.16b + eor v17.16b, v17.16b, v17.16b + eor v18.16b, v18.16b, v18.16b + eor v19.16b, v19.16b, v19.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + + // Initialize per-iteration pointers. + mov x10, x8 // input row 0 base + mov x11, x1 // packed filter base + mov x12, x5 // remaining input channel blocks + +181: // InputChannels loop (1 output) + // Prefetch the next panels. + prfm pldl1keep, [x10, #256] + prfm pldl1keep, [x11, #512] + + // k-groups 0-1 input conversion (row 0 + implicit zero row 1). + ldp q0, q1, [x10], #32 + eor v2.16b, v2.16b, v2.16b + bfcvtn v2.4h, v0.4s + + // k-group 0, column pairs 0-3. + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x11], #64 + bfmmla v16.4s, v2.8h, v4.8h + bfmmla v17.4s, v2.8h, v5.8h + bfmmla v18.4s, v2.8h, v6.8h + bfmmla v19.4s, v2.8h, v7.8h + + // k-group 0, column pairs 4-7. + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x11], #64 + bfmmla v20.4s, v2.8h, v4.8h + bfmmla v21.4s, v2.8h, v5.8h + bfmmla v22.4s, v2.8h, v6.8h + bfmmla v23.4s, v2.8h, v7.8h + + // k-group 1 (upper BF16 row is kept zero via the prior EOR). + eor v7.16b, v7.16b, v7.16b + bfcvtn v7.4h, v1.4s + + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x11], #64 + bfmmla v16.4s, v7.8h, v3.8h + bfmmla v17.4s, v7.8h, v4.8h + bfmmla v18.4s, v7.8h, v5.8h + bfmmla v19.4s, v7.8h, v6.8h + + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x11], #64 + bfmmla v20.4s, v7.8h, v3.8h + bfmmla v21.4s, v7.8h, v4.8h + bfmmla v22.4s, v7.8h, v5.8h + bfmmla v23.4s, v7.8h, v6.8h + + // k-groups 2-3 input conversion. + ldp q0, q1, [x10], #32 + eor v2.16b, v2.16b, v2.16b + bfcvtn v2.4h, v0.4s + + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x11], #64 + bfmmla v16.4s, v2.8h, v3.8h + bfmmla v17.4s, v2.8h, v4.8h + bfmmla v18.4s, v2.8h, v5.8h + bfmmla v19.4s, v2.8h, v6.8h + + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x11], #64 + bfmmla v20.4s, v2.8h, v3.8h + bfmmla v21.4s, v2.8h, v4.8h + bfmmla v22.4s, v2.8h, v5.8h + bfmmla v23.4s, v2.8h, v6.8h + + eor v7.16b, v7.16b, v7.16b + bfcvtn v7.4h, v1.4s + + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x11], #64 + bfmmla v16.4s, v7.8h, v3.8h + bfmmla v17.4s, v7.8h, v4.8h + bfmmla v18.4s, v7.8h, v5.8h + bfmmla v19.4s, v7.8h, v6.8h + + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x11], #64 + bfmmla v20.4s, v7.8h, v3.8h + bfmmla v21.4s, v7.8h, v4.8h + bfmmla v22.4s, v7.8h, v5.8h + bfmmla v23.4s, v7.8h, v6.8h + + // Advance to the next input channel block. + add x10, x10, x17 + subs x12, x12, #1 + b.ne 181b + + // Store the single output row. Because the second row was zeroed, zip1 + // cleanly gathers the row-0 values without needing zip2. + zip1 v0.2d, v16.2d, v17.2d + zip1 v1.2d, v18.2d, v19.2d + zip1 v2.2d, v20.2d, v21.2d + zip1 v3.2d, v22.2d, v23.2d + + stp q0, q1, [x9], #32 + stp q2, q3, [x9], #32 + + // Advance to the next output. 
+    add     x8, x8, x3
+    subs    x7, x7, #1
+    b.ne    180b
+
+199:
+    ret
+
+#endif  // defined(__aarch64__)
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 799170565d3a4..6a6b443f6728c 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1024,7 +1024,10 @@ extern "C" {
     // AArch64 assembly micro-kernel for pointwise NCHWc convolution
     MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeonAsm;
 #endif
-#if defined(__aarch64__) && defined(__linux__)
+#if defined(__linux__)
+    // AArch64 assembly fast-math micro-kernels
+    MLAS_CONV_FLOAT_KERNEL MlasConvNchwBf16KernelNeon;
+    MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseBf16KernelNeon;
     MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon;
 #endif
     MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon;
@@ -1448,7 +1451,9 @@ struct MLAS_PLATFORM {
     MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel;
     MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel;
     MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel;
-#if defined(__aarch64__) && defined(__linux__)
+#if defined(__linux__)
+    MLAS_CONV_FLOAT_KERNEL* ConvNchwBf16Kernel;
+    MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseBf16Kernel;
     MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseBf16Kernel;
 #endif
     MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount];
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index eccde79848e61..e7d4bf12aa289 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -583,7 +583,9 @@ Return Value:
         this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
         this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
         this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
-#if defined(__aarch64__) && defined(__linux__)
+#if defined(__linux__)
+        this->ConvNchwBf16Kernel = MlasConvNchwBf16KernelNeon;
+        this->ConvDepthwiseBf16Kernel = MlasConvDepthwiseBf16KernelNeon;
         this->ConvPointwiseBf16Kernel = MlasConvPointwiseBf16KernelNeon;
 #endif
         this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon;
diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp
index d513791c695f5..b397520e26072 100644
--- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp
@@ -14,13 +14,881 @@ Module Name:
 
 --*/
 
-#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)
-
 #include "mlasi.h"
+
+#if defined(__linux__) && defined(MLAS_USE_ARM_NEON_NCHWC)
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+
 #include "sconv_nchwc_kernel_neon.h"
 
 constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#define MLAS_CONV_BF16_HELPERS_AVAILABLE
+#define MLAS_CONV_BF16_POINTWISE_INTRINSICS_AVAILABLE
+#endif
+
+#if defined(__aarch64__) && !defined(_WIN32)
+
+struct MLAS_NCHW_BF16_MMLA_PARAMS {
+    const float* Input;
+    const uint16_t* PackedFilter;
+    float* Output;
+    const float* Bias;
+    size_t StrideWidthElements;
+    size_t DilatedInputWidthElements;
+    size_t OutputCount;
+    size_t FilterCount;
+    size_t OutputStrideElements;
+    unsigned KernelFlags;
+    unsigned Reserved;
+};
+
+#define MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(Field, Offset) \
+    static_assert(offsetof(MLAS_NCHW_BF16_MMLA_PARAMS, Field) == Offset, #Field " offset mismatch")
+
+MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(Input, 0);
+MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(PackedFilter, 8); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(Output, 16); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(Bias, 24); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(StrideWidthElements, 32); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(DilatedInputWidthElements, 40); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(OutputCount, 48); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(FilterCount, 56); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(OutputStrideElements, 64); +MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT(KernelFlags, 72); + +#undef MLAS_NCHW_BF16_MMLA_PARAM_OFFSET_ASSERT + +extern "C" void +MLASCALL +MlasConvNchwBf16KernelNeonAsm(const MLAS_NCHW_BF16_MMLA_PARAMS* Params); + +extern "C" void +MLASCALL +MlasConvNchwBf16PackFilterNeonAsm( + const float* Filter, + size_t FilterStrideElements, + size_t FilterCount, + uint16_t* PackedFilter + ); + +#endif + +#if defined(MLAS_CONV_BF16_HELPERS_AVAILABLE) + +extern "C" void +MLASCALL +MlasConvBf16OutputPostProcessNeonAsm( + float* Output, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ); + +extern "C" void +MLASCALL +MlasConvDepthwiseBf16KernelNeon3x3DispatchAsm( + const float* Input, + const float* Filter, + float* Output, + size_t OutputCount, + size_t StrideWidthBytes, + size_t DilationWidthBytes, + size_t DilatedInputWidthBytes, + const float* Bias, + unsigned KernelFlags + ); + +extern "C" void +MLASCALL +MlasConvPointwiseBf16PackFilterNeonAsm( + const float* Filter, + size_t InputChannels, + uint16_t* PackedFilter + ); + +extern "C" void +MLASCALL +MlasConvPointwiseBf16KernelNeonAsm( + const float* Input, + const uint16_t* PackedFilter, + float* Output, + size_t StrideWidthBytes, + size_t InputStrideBytes, + size_t InputChannels, + size_t OutputCountEven + ); + +extern "C" void +MLASCALL +MlasConvPointwiseBf16KernelNeonSingleOutputAsm( + const float* Input, + const uint16_t* PackedFilter, + float* Output, + size_t StrideWidthBytes, + size_t InputStrideBytes, + size_t InputChannels, + size_t OutputCount + ); + +extern "C" void +MLASCALL +MlasConvPointwiseBf16PackedInputKernelNeon4xAsm( + const uint16_t* PackedInput, + const uint16_t* PackedFilter, + float* Output, + size_t OutputQuartetCount, + size_t InputChannels + ); + +extern "C" void +MLASCALL +MlasConvPointwiseBf16PackedInputKernelNeon2xAsm( + const uint16_t* PackedInput, + const uint16_t* PackedFilter, + float* Output, + size_t OutputPairCount, + size_t InputChannels + ); + +#endif + +#if defined(MLAS_CONV_BF16_HELPERS_AVAILABLE) + +static inline unsigned +MlasConvBf16SemanticKernelFlags( + unsigned KernelFlags + ) +{ + return KernelFlags & ~MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI; +} + +static inline const float* +MlasConvBf16BiasData( + const float* Bias, + unsigned KernelFlags + ) +{ + return ((KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0) ? 
Bias : nullptr; +} + +static inline void +MlasConvBf16PostProcessOutputs( + float* Output, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ) +{ + const unsigned PostProcessFlags = + KernelFlags & (MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION); + + if (OutputCount == 0 || PostProcessFlags == 0) { + return; + } + +#if defined(__aarch64__) + MlasConvBf16OutputPostProcessNeonAsm( + Output, + OutputCount, + MlasConvBf16BiasData(Bias, PostProcessFlags), + PostProcessFlags + ); + return; +#endif + + const bool BiasAddition = (PostProcessFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (PostProcessFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + float32x4_t BiasVector0 = ZeroVector; + float32x4_t BiasVector1 = ZeroVector; + float32x4_t BiasVector2 = ZeroVector; + float32x4_t BiasVector3 = ZeroVector; + if (BiasAddition) { + BiasVector0 = MlasLoadFloat32x4(Bias); + BiasVector1 = MlasLoadFloat32x4(Bias + 4); + BiasVector2 = MlasLoadFloat32x4(Bias + 8); + BiasVector3 = MlasLoadFloat32x4(Bias + 12); + } + + for (size_t OutputIndex = 0; OutputIndex < OutputCount; ++OutputIndex) { + float* OutputRow = Output + OutputIndex * BlockSize; + + float32x4_t Accumulator0 = MlasLoadFloat32x4(OutputRow); + float32x4_t Accumulator1 = MlasLoadFloat32x4(OutputRow + 4); + float32x4_t Accumulator2 = MlasLoadFloat32x4(OutputRow + 8); + float32x4_t Accumulator3 = MlasLoadFloat32x4(OutputRow + 12); + + if (BiasAddition) { + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(OutputRow, Accumulator0); + MlasStoreFloat32x4(OutputRow + 4, Accumulator1); + MlasStoreFloat32x4(OutputRow + 8, Accumulator2); + MlasStoreFloat32x4(OutputRow + 12, Accumulator3); + } +} + +#endif + +#if defined(MLAS_CONV_BF16_POINTWISE_INTRINSICS_AVAILABLE) + +static inline void +MlasPackPointwiseBf16InputPairNeon( + const float* Input0, + const float* Input1, + uint16_t* PackedInput + ) +{ + for (size_t KGroup = 0; KGroup < BlockSize / 4; ++KGroup) { + bfloat16x8_t InputPair = vcvtq_low_bf16_f32(vld1q_f32(Input0 + KGroup * 4)); + InputPair = vcvtq_high_bf16_f32(InputPair, vld1q_f32(Input1 + KGroup * 4)); + vst1q_u16(PackedInput + KGroup * 8, vreinterpretq_u16_bf16(InputPair)); + } +} + +static inline void +MlasPackPointwiseBf16InputQuartetNeon( + const float* Input0, + const float* Input1, + const float* Input2, + const float* Input3, + uint16_t* PackedInput + ) +{ + MlasPackPointwiseBf16InputPairNeon(Input0, Input1, PackedInput); + MlasPackPointwiseBf16InputPairNeon(Input2, Input3, PackedInput + BlockSize * 2); +} + +static inline void +MlasPackPointwiseBf16InputPairsNeon( + const float* Input, + uint16_t* PackedInput, + size_t StrideWidthElements, + size_t InputStrideElements, + size_t InputChannels, + size_t OutputPairCount + ) +{ + constexpr size_t PointwisePackedInputPairSizeBf16 = BlockSize * 2; + + for (size_t OutputPairIndex = 0; OutputPairIndex < OutputPairCount; ++OutputPairIndex) { + 
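+        // Walk both rows of this output pair across the input-channel blocks;
+        // each block appends four 8-lane BF16 records (k-groups 0-3) with row 0
+        // in lanes 0-3 and row 1 in lanes 4-7, matching the layout documented
+        // on the 2x asm kernel.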
const float* Input0 = Input + OutputPairIndex * 2 * StrideWidthElements; + const float* Input1 = Input0 + StrideWidthElements; + + for (size_t InputChannelIndex = 0; InputChannelIndex < InputChannels; ++InputChannelIndex) { + MlasPackPointwiseBf16InputPairNeon( + Input0 + InputChannelIndex * InputStrideElements, + Input1 + InputChannelIndex * InputStrideElements, + PackedInput + ); + PackedInput += PointwisePackedInputPairSizeBf16; + } + } +} + +static inline void +MlasPackPointwiseBf16InputQuartetsNeon( + const float* Input, + uint16_t* PackedInput, + size_t StrideWidthElements, + size_t InputStrideElements, + size_t InputChannels, + size_t OutputQuartetCount + ) +{ + constexpr size_t PointwisePackedInputQuartetSizeBf16 = BlockSize * 4; + + for (size_t OutputQuartetIndex = 0; OutputQuartetIndex < OutputQuartetCount; ++OutputQuartetIndex) { + const float* Input0 = Input + OutputQuartetIndex * 4 * StrideWidthElements; + const float* Input1 = Input0 + StrideWidthElements; + const float* Input2 = Input1 + StrideWidthElements; + const float* Input3 = Input2 + StrideWidthElements; + + for (size_t InputChannelIndex = 0; InputChannelIndex < InputChannels; ++InputChannelIndex) { + MlasPackPointwiseBf16InputQuartetNeon( + Input0 + InputChannelIndex * InputStrideElements, + Input1 + InputChannelIndex * InputStrideElements, + Input2 + InputChannelIndex * InputStrideElements, + Input3 + InputChannelIndex * InputStrideElements, + PackedInput + ); + PackedInput += PointwisePackedInputQuartetSizeBf16; + } + } +} + +static inline void +MlasMergePointwiseBf16OutputsNeon( + const float* PartialOutput, + float* Output, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ) +{ + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(BiasAddition ? -1 : 0)); + const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(ReluActivation ? 
-1 : 0)); + + float32x4_t BiasVector0 = ZeroVector; + float32x4_t BiasVector1 = ZeroVector; + float32x4_t BiasVector2 = ZeroVector; + float32x4_t BiasVector3 = ZeroVector; + + if (BiasAddition && Bias != nullptr) { + BiasVector0 = MlasLoadFloat32x4(Bias); + BiasVector1 = MlasLoadFloat32x4(Bias + 4); + BiasVector2 = MlasLoadFloat32x4(Bias + 8); + BiasVector3 = MlasLoadFloat32x4(Bias + 12); + } + + for (size_t OutputIndex = 0; OutputIndex < OutputCount; ++OutputIndex) { + const float* PartialRow = PartialOutput + OutputIndex * BlockSize; + float* OutputRow = Output + OutputIndex * BlockSize; + + float32x4_t Accumulator0 = MlasAddFloat32x4(MlasLoadFloat32x4(PartialRow), MlasLoadFloat32x4(OutputRow)); + float32x4_t Accumulator1 = MlasAddFloat32x4(MlasLoadFloat32x4(PartialRow + 4), MlasLoadFloat32x4(OutputRow + 4)); + float32x4_t Accumulator2 = MlasAddFloat32x4(MlasLoadFloat32x4(PartialRow + 8), MlasLoadFloat32x4(OutputRow + 8)); + float32x4_t Accumulator3 = MlasAddFloat32x4(MlasLoadFloat32x4(PartialRow + 12), MlasLoadFloat32x4(OutputRow + 12)); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, MlasAndFloat32x4(BiasVector0, BiasMask)); + Accumulator1 = MlasAddFloat32x4(Accumulator1, MlasAndFloat32x4(BiasVector1, BiasMask)); + Accumulator2 = MlasAddFloat32x4(Accumulator2, MlasAndFloat32x4(BiasVector2, BiasMask)); + Accumulator3 = MlasAddFloat32x4(Accumulator3, MlasAndFloat32x4(BiasVector3, BiasMask)); + + float32x4_t Relu0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + float32x4_t Relu1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + float32x4_t Relu2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + float32x4_t Relu3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + + Accumulator0 = MlasBlendFloat32x4(Accumulator0, Relu0, ReluMask); + Accumulator1 = MlasBlendFloat32x4(Accumulator1, Relu1, ReluMask); + Accumulator2 = MlasBlendFloat32x4(Accumulator2, Relu2, ReluMask); + Accumulator3 = MlasBlendFloat32x4(Accumulator3, Relu3, ReluMask); + + MlasStoreFloat32x4(OutputRow, Accumulator0); + MlasStoreFloat32x4(OutputRow + 4, Accumulator1); + MlasStoreFloat32x4(OutputRow + 8, Accumulator2); + MlasStoreFloat32x4(OutputRow + 12, Accumulator3); + } +} + +#endif + +namespace { + +#if defined(MLAS_CONV_BF16_POINTWISE_INTRINSICS_AVAILABLE) + +static void +MlasConvPointwiseFloatKernelNeonBf16Mmla( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount + ) +{ + constexpr size_t PointwiseInputChannelsMax = 8; + constexpr size_t PointwiseFilterCountMax = 4; + constexpr size_t PointwisePackedFilterStrideBf16 = 256; + constexpr size_t PointwisePackedFilterSizeBf16 = PointwiseInputChannelsMax * PointwisePackedFilterStrideBf16; + constexpr size_t PointwisePackedInputQuartetSizeBf16 = BlockSize * 4; + constexpr size_t PointwisePackedInputPairSizeBf16 = BlockSize * 2; + constexpr size_t PointwiseOutputQuartetBatchMax = 8; + constexpr size_t PointwiseOutputPairBatchMax = 16; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + const size_t OutputCountQuartet = OutputCount & ~size_t{3}; + const size_t OutputCountEven = OutputCount & ~size_t{1}; + const size_t OutputCountPairTail = OutputCountEven - OutputCountQuartet; + const 
size_t OutputCountRemainder = OutputCount - OutputCountEven; + const float* InputRemainder = Input + OutputCountEven * StrideWidthElements; + + alignas(16) uint16_t PackedFilters[PointwiseFilterCountMax * PointwisePackedFilterSizeBf16]; + + for (size_t f = 0; f < FilterCount; ++f) { + const float* filter = Filter + f * FilterStrideElements; + uint16_t* packed_filter = PackedFilters + f * PointwisePackedFilterSizeBf16; + MlasConvPointwiseBf16PackFilterNeonAsm(filter, InputChannels, packed_filter); + } + + if (FilterCount > 1) { + if (OutputCountQuartet != 0) { + alignas(16) uint16_t PackedInput[PointwiseOutputQuartetBatchMax * PointwiseInputChannelsMax * PointwisePackedInputQuartetSizeBf16]; + + size_t OutputIndex = 0; + while (OutputIndex < OutputCountQuartet) { + const size_t OutputQuartetCount = std::min( + (OutputCountQuartet - OutputIndex) / 4, + PointwiseOutputQuartetBatchMax + ); + + MlasPackPointwiseBf16InputQuartetsNeon( + Input + OutputIndex * StrideWidthElements, + PackedInput, + StrideWidthElements, + InputStrideElements, + InputChannels, + OutputQuartetCount + ); + + for (size_t f = 0; f < FilterCount; ++f) { + MlasConvPointwiseBf16PackedInputKernelNeon4xAsm( + PackedInput, + PackedFilters + f * PointwisePackedFilterSizeBf16, + Output + f * OutputStrideElements + OutputIndex * BlockSize, + OutputQuartetCount, + InputChannels + ); + } + + OutputIndex += OutputQuartetCount * 4; + } + } + + if (OutputCountPairTail != 0) { + alignas(16) uint16_t PackedInput[PointwiseOutputPairBatchMax * PointwiseInputChannelsMax * PointwisePackedInputPairSizeBf16]; + + MlasPackPointwiseBf16InputPairsNeon( + Input + OutputCountQuartet * StrideWidthElements, + PackedInput, + StrideWidthElements, + InputStrideElements, + InputChannels, + OutputCountPairTail / 2 + ); + + for (size_t f = 0; f < FilterCount; ++f) { + MlasConvPointwiseBf16PackedInputKernelNeon2xAsm( + PackedInput, + PackedFilters + f * PointwisePackedFilterSizeBf16, + Output + f * OutputStrideElements + OutputCountQuartet * BlockSize, + OutputCountPairTail / 2, + InputChannels + ); + } + } + } else { + for (size_t f = 0; f < FilterCount; ++f) { + float* output = Output + f * OutputStrideElements; + + if (OutputCountEven != 0) { + MlasConvPointwiseBf16KernelNeonAsm( + Input, + PackedFilters + f * PointwisePackedFilterSizeBf16, + output, + StrideWidth, + InputStride, + InputChannels, + OutputCountEven); + } + } + } + + if (OutputCountRemainder != 0) { + for (size_t f = 0; f < FilterCount; ++f) { + float* output_remainder = Output + f * OutputStrideElements + OutputCountEven * BlockSize; + MlasConvPointwiseBf16KernelNeonSingleOutputAsm( + InputRemainder, + PackedFilters + f * PointwisePackedFilterSizeBf16, + output_remainder, + StrideWidth, + InputStride, + InputChannels, + OutputCountRemainder); + } + } +} + +#endif + +} // namespace + +void +MLASCALL +MlasConvNchwBf16KernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ +#if defined(__aarch64__) && !defined(_WIN32) + constexpr size_t Bf16MmlaKernelHeight = 3; + constexpr size_t Bf16MmlaKernelWidth = 3; + constexpr size_t Bf16MmlaMaxFilterCount = 4; + constexpr size_t 
Bf16MmlaPackedFilterStrideBf16 = 192; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + const bool UseBf16MmlaKernel = + KernelHeight == Bf16MmlaKernelHeight && + KernelWidth == Bf16MmlaKernelWidth && + DilationWidthElements == 1 && + FilterCount > 0 && FilterCount <= Bf16MmlaMaxFilterCount && + OutputCount >= 2; + + if (UseBf16MmlaKernel) { + auto ConvFallbackSegment = [&](size_t outputOffset, size_t segmentCount) { + if (segmentCount == 0) { + return; + } + + MlasConvNchwFloatKernelNeon( + Input + outputOffset * StrideWidthElements, + Filter, + Output + outputOffset * BlockSize, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + 0, + segmentCount, + 0, + Bias, + KernelFlags + ); + }; + + size_t OutputOffset = 0; + + ConvFallbackSegment(OutputOffset, OutputCountLeftPad); + OutputOffset += OutputCountLeftPad; + + const size_t OutputCountEven = OutputCount & ~size_t{1}; + + if (OutputCountEven != 0) { + alignas(16) uint16_t PackedFilter[Bf16MmlaMaxFilterCount * Bf16MmlaPackedFilterStrideBf16]; + MlasConvNchwBf16PackFilterNeonAsm(Filter, FilterStrideElements, FilterCount, PackedFilter); + + MLAS_NCHW_BF16_MMLA_PARAMS Params; + Params.Input = Input + OutputOffset * StrideWidthElements; + Params.PackedFilter = PackedFilter; + Params.Output = Output + OutputOffset * BlockSize; + Params.Bias = Bias; + Params.StrideWidthElements = StrideWidthElements; + Params.DilatedInputWidthElements = DilatedInputWidthElements; + Params.OutputCount = OutputCountEven; + Params.FilterCount = FilterCount; + Params.OutputStrideElements = OutputStrideElements; + Params.KernelFlags = KernelFlags; + Params.Reserved = 0; + + MlasConvNchwBf16KernelNeonAsm(&Params); + + OutputOffset += OutputCountEven; + } + + ConvFallbackSegment(OutputOffset, OutputCount - OutputCountEven); + OutputOffset += OutputCount - OutputCountEven; + + ConvFallbackSegment(OutputOffset, OutputCountRightPad); + return; + } +#endif + + MlasConvNchwFloatKernelNeon( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +void +MLASCALL +MlasConvDepthwiseBf16KernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ +#if defined(__aarch64__) && defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const unsigned SemanticKernelFlags = MlasConvBf16SemanticKernelFlags(KernelFlags); + const bool AccumulateOutput = (SemanticKernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + + constexpr bool BlockSizeSupported = (BlockSize == 16); + 
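+    // The asm hot path covers only the validated interior case: a 3x3 kernel,
+    // unit dilation in NCHWc terms (one block per dilation step), no output
+    // accumulation, and a non-empty unpadded span. Everything else falls back
+    // to the C++ kernel below.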
+    const bool CanUseAsmHotPath =
+        BlockSizeSupported &&
+        !AccumulateOutput &&
+        KernelHeight == 3 &&
+        KernelWidth == 3 &&
+        DilationWidthElements == BlockSize &&
+        DilatedInputWidth > 2 * DilationWidth &&
+        OutputCount > 0;
+
+    if (CanUseAsmHotPath) {
+        if (OutputCountLeftPad != 0) {
+            MlasConvDepthwiseFloatKernelNeon(
+                Input,
+                Filter,
+                Output,
+                StrideWidth,
+                DilationWidth,
+                InputStride,
+                KernelHeight,
+                KernelWidth,
+                InputBase,
+                InputWidth,
+                DilatedInputWidth,
+                0,
+                OutputCountLeftPad,
+                0,
+                Bias,
+                KernelFlags
+            );
+        }
+
+        MlasConvDepthwiseBf16KernelNeon3x3DispatchAsm(
+            Input + OutputCountLeftPad * StrideWidthElements,
+            Filter,
+            Output + OutputCountLeftPad * BlockSize,
+            OutputCount,
+            StrideWidth,
+            DilationWidth,
+            DilatedInputWidth,
+            MlasConvBf16BiasData(Bias, SemanticKernelFlags),
+            SemanticKernelFlags
+        );
+
+        if (OutputCountRightPad != 0) {
+            const size_t RightOutputStart = OutputCountLeftPad + OutputCount;
+            MlasConvDepthwiseFloatKernelNeon(
+                Input + RightOutputStart * StrideWidthElements,
+                Filter,
+                Output + RightOutputStart * BlockSize,
+                StrideWidth,
+                DilationWidth,
+                InputStride,
+                KernelHeight,
+                KernelWidth,
+                InputBase,
+                InputWidth,
+                DilatedInputWidth,
+                0,
+                OutputCountRightPad,
+                0,
+                Bias,
+                KernelFlags
+            );
+        }
+
+        return;
+    }
+#endif
+
+    MlasConvDepthwiseFloatKernelNeon(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        DilationWidth,
+        InputStride,
+        KernelHeight,
+        KernelWidth,
+        InputBase,
+        InputWidth,
+        DilatedInputWidth,
+        OutputCountLeftPad,
+        OutputCount,
+        OutputCountRightPad,
+        Bias,
+        KernelFlags
+    );
+}
+
+bool
+MLASCALL
+MlasTryConvPointwiseBf16KernelNeonAsm(
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    size_t StrideWidth,
+    size_t InputChannels,
+    size_t FilterCount,
+    size_t InputStride,
+    size_t FilterStride,
+    size_t OutputStride,
+    size_t OutputCount,
+    const float* Bias,
+    unsigned KernelFlags
+    )
+{
+#if defined(MLAS_CONV_BF16_POINTWISE_INTRINSICS_AVAILABLE)
+    constexpr size_t PointwiseFilterCountMax = 4;
+    constexpr size_t PointwiseInputChannelsMax = 8;
+
+    const unsigned SemanticKernelFlags = MlasConvBf16SemanticKernelFlags(KernelFlags);
+    const bool AccumulateOutput = (SemanticKernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
+
+    if (BlockSize != 16 ||
+        OutputCount == 0 ||
+        InputChannels == 0 || InputChannels > PointwiseInputChannelsMax ||
+        FilterCount == 0 || FilterCount > PointwiseFilterCountMax) {
+        return false;
+    }
+
+    float* KernelOutput = Output;
+    size_t KernelOutputStride = OutputStride;
+    size_t ScratchOutputStrideElements = 0;
+
+    if (AccumulateOutput) {
+        ScratchOutputStrideElements = OutputCount * BlockSize;
+        const size_t ScratchOutputBytes = UpAlignSize(FilterCount * ScratchOutputStrideElements * sizeof(float));
+        MlasThreadedBufAlloc(ScratchOutputBytes);
+
+        if (ThreadedBufHolder.get() == nullptr) {
+            return false;
+        }
+
+        KernelOutput = reinterpret_cast<float*>(ThreadedBufHolder.get());
+        KernelOutputStride = ScratchOutputStrideElements * sizeof(float);
+    }
+
+    MlasConvPointwiseFloatKernelNeonBf16Mmla(
+        Input,
+        Filter,
+        KernelOutput,
+        StrideWidth,
+        InputChannels,
+        FilterCount,
+        InputStride,
+        FilterStride,
+        KernelOutputStride,
+        OutputCount);
+
+    const unsigned PostProcessFlags =
+        SemanticKernelFlags & (MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION);
+
+    if (AccumulateOutput) {
+        const size_t OutputStrideElements = OutputStride / sizeof(float);
+        const float* BiasData = MlasConvBf16BiasData(Bias, PostProcessFlags);
+
+        for (size_t f = 0; f <
FilterCount; ++f) { + MlasMergePointwiseBf16OutputsNeon( + KernelOutput + f * ScratchOutputStrideElements, + Output + f * OutputStrideElements, + OutputCount, + BiasData == nullptr ? nullptr : BiasData + f * BlockSize, + PostProcessFlags + ); + } + } else if (PostProcessFlags != 0) { + const size_t OutputStrideElements = OutputStride / sizeof(float); + const float* BiasData = MlasConvBf16BiasData(Bias, PostProcessFlags); + + for (size_t f = 0; f < FilterCount; ++f) { + MlasConvBf16PostProcessOutputs( + Output + f * OutputStrideElements, + OutputCount, + BiasData == nullptr ? nullptr : BiasData + f * BlockSize, + PostProcessFlags + ); + } + } + + return true; +#else + MLAS_UNREFERENCED_PARAMETER(Input); + MLAS_UNREFERENCED_PARAMETER(Filter); + MLAS_UNREFERENCED_PARAMETER(Output); + MLAS_UNREFERENCED_PARAMETER(StrideWidth); + MLAS_UNREFERENCED_PARAMETER(InputChannels); + MLAS_UNREFERENCED_PARAMETER(FilterCount); + MLAS_UNREFERENCED_PARAMETER(InputStride); + MLAS_UNREFERENCED_PARAMETER(FilterStride); + MLAS_UNREFERENCED_PARAMETER(OutputStride); + MLAS_UNREFERENCED_PARAMETER(OutputCount); + MLAS_UNREFERENCED_PARAMETER(Bias); + MLAS_UNREFERENCED_PARAMETER(KernelFlags); + return false; +#endif +} + // // BF16 Pointwise (1x1) Convolution Kernel using SBGEMM. // @@ -44,6 +912,22 @@ MlasConvPointwiseBf16KernelNeon( const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + if (MlasTryConvPointwiseBf16KernelNeonAsm( + Input, + Filter, + Output, + StrideWidth, + InputChannels, + FilterCount, + InputStride, + FilterStride, + OutputStride, + OutputCount, + Bias, + KernelFlags)) { + return; + } + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; // TODO(hasesh): With the ARM KleidiAI team, study the impact of using the KleidiAI SBGEMM kernel for this convolution kernel @@ -115,4 +999,7 @@ MlasConvPointwiseBf16KernelNeon( } } +#undef MLAS_CONV_BF16_POINTWISE_INTRINSICS_AVAILABLE +#undef MLAS_CONV_BF16_HELPERS_AVAILABLE + #endif diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp index 6ca2398f91503..bd9dd737f6e98 100644 --- a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp @@ -50,8 +50,9 @@ void const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT))); const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; - const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-static_cast(BiasAddition))); - const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION))); + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(BiasAddition ? -1 : 0)); + const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(ReluActivation ? -1 : 0)); const size_t StrideWidthElements = StrideWidth / sizeof(float); const size_t DilationWidthElements = DilationWidth / sizeof(float); @@ -314,13 +315,11 @@ void // Implementation of MlasConvDepthwiseFloatKernelNeon // // This kernel performs depthwise separable convolution where each input channel -// is convolved with its own filter. 
+    if (MlasTryConvPointwiseBf16KernelNeonAsm(
+            Input,
+            Filter,
+            Output,
+            StrideWidth,
+            InputChannels,
+            FilterCount,
+            InputStride,
+            FilterStride,
+            OutputStride,
+            OutputCount,
+            Bias,
+            KernelFlags)) {
+        return;
+    }
+
     MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config;
 
     // TODO(hasesh): With the ARM KleidiAI team, study the impact of using the KleidiAI SBGEMM kernel for this convolution kernel
@@ -115,4 +999,7 @@ MlasConvPointwiseBf16KernelNeon(
     }
 }
 
+#undef MLAS_CONV_BF16_POINTWISE_INTRINSICS_AVAILABLE
+#undef MLAS_CONV_BF16_HELPERS_AVAILABLE
+
 #endif
diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
index 6ca2398f91503..bd9dd737f6e98 100644
--- a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
@@ -50,8 +50,9 @@ void
     const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
     const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)));
     const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
-    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-static_cast<int32_t>(BiasAddition)));
-    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
+    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
+    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(BiasAddition ? -1 : 0));
+    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(ReluActivation ? -1 : 0));
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -314,13 +315,11 @@ void
 // Implementation of MlasConvDepthwiseFloatKernelNeon
 //
 // This kernel performs depthwise separable convolution where each input channel
-// is convolved with its own filter. This is more efficient than standard convolution
-// for certain network architectures like MobileNets.
+// is convolved with its own filter.
 //
-void
-    MLASCALL
-    MlasConvDepthwiseFloatKernelNeon(
+static void
+MlasConvDepthwiseFloatKernelNeonImpl(
     const float* Input,
     const float* Filter,
     float* Output,
     size_t StrideWidth,
@@ -341,8 +340,10 @@ void
 {
     const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
     const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)));
-    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)));
-    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
+    const bool BiasAdditionEnabled = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
+    const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;
+    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(BiasAdditionEnabled ? -1 : 0));
+    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(ReluActivation ? -1 : 0));
 
     const size_t StrideWidthElements = StrideWidth / sizeof(float);
     const size_t DilationWidthElements = DilationWidth / sizeof(float);
@@ -354,7 +355,9 @@ void
 
     const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
 
-    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
+    const size_t KernelSize = KernelHeight * KernelWidth;
+
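+    // Computes one 16-float NCHWC output block; invoked for every output
+    // position by the loop that follows the lambda body.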
+    auto ComputeDepthwiseOutput = [&](size_t output_idx) {
         float32x4_t OldOutput0 = MlasLoadFloat32x4(&Output[output_idx * BlockSize]);
         float32x4_t OldOutput1 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 4]);
 
@@ -378,7 +381,7 @@ void
             BiasVector3 = MlasLoadFloat32x4(Bias + 12);
         }
 
-        for (size_t kernel_pos = 0; kernel_pos < KernelHeight * KernelWidth; kernel_pos++) {
+        for (size_t kernel_pos = 0; kernel_pos < KernelSize; kernel_pos++) {
             size_t kh = kernel_pos / KernelWidth;
             size_t kw = kernel_pos % KernelWidth;
 
@@ -425,33 +428,75 @@ void
         MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], Accumulator1);
         MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], Accumulator2);
         MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], Accumulator3);
+
+    };
+
+    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
+        ComputeDepthwiseOutput(output_idx);
+    }
 }
 
-//
-// Implementation of MlasConvPointwiseFloatKernelNeon
-//
-// Performs pointwise (1x1) convolution on NCHWC formatted data using batched
-// GEMM. Input channels are strided by InputStride, requiring separate GEMMs
-// per channel block which are accumulated into the output.
-//
-
 void
     MLASCALL
-    MlasConvPointwiseFloatKernelNeon(
+    MlasConvDepthwiseFloatKernelNeon(
     const float* Input,
     const float* Filter,
     float* Output,
     size_t StrideWidth,
-    size_t InputChannels,
-    size_t FilterCount,
+    size_t DilationWidth,
     size_t InputStride,
-    size_t FilterStride,
-    size_t OutputStride,
+    size_t KernelHeight,
+    size_t KernelWidth,
+    const float* InputBase,
+    size_t InputWidth,
+    size_t DilatedInputWidth,
+    size_t OutputCountLeftPad,
     size_t OutputCount,
+    size_t OutputCountRightPad,
     const float* Bias,
     unsigned KernelFlags
     )
+{
+    MlasConvDepthwiseFloatKernelNeonImpl(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        DilationWidth,
+        InputStride,
+        KernelHeight,
+        KernelWidth,
+        InputBase,
+        InputWidth,
+        DilatedInputWidth,
+        OutputCountLeftPad,
+        OutputCount,
+        OutputCountRightPad,
+        Bias,
+        KernelFlags);
+}
+
+//
+// Pointwise convolution helpers.
+//
+
+namespace {
+
+static void
+MlasConvPointwiseFloatKernelNeonFallback(
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    size_t StrideWidth,
+    size_t InputChannels,
+    size_t FilterCount,
+    size_t InputStride,
+    size_t FilterStride,
+    size_t OutputStride,
+    size_t OutputCount,
+    const float* Bias,
+    unsigned KernelFlags
+    )
 {
     const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
     const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
@@ -535,4 +580,38 @@ void
     }
 }
 
+} // namespace
+
+void
+    MLASCALL
+    MlasConvPointwiseFloatKernelNeon(
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    size_t StrideWidth,
+    size_t InputChannels,
+    size_t FilterCount,
+    size_t InputStride,
+    size_t FilterStride,
+    size_t OutputStride,
+    size_t OutputCount,
+    const float* Bias,
+    unsigned KernelFlags
+    )
+{
+    MlasConvPointwiseFloatKernelNeonFallback(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        InputChannels,
+        FilterCount,
+        InputStride,
+        FilterStride,
+        OutputStride,
+        OutputCount,
+        Bias,
+        KernelFlags);
+}
+
 #endif
diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp
index 8c4e3cf8fae42..fb816e0d1d2f7 100644
--- a/onnxruntime/core/mlas/lib/snchwc.cpp
+++ b/onnxruntime/core/mlas/lib/snchwc.cpp
@@ -795,6 +795,12 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
 
 #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
     MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel;
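+    // Keep the float kernel pointer on hand so individual rows can fall back
+    // when the BF16 asm's trailing over-read would be unsafe (see below).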
+#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)
+    MLAS_CONV_FLOAT_KERNEL* const KernelFloat = GetMlasPlatform().ConvNchwFloatKernel;
+    if (WorkBlock->UseBf16) {
+        Kernel = GetMlasPlatform().ConvNchwBf16Kernel;
+    }
+#endif
 #else
     MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel;
 #endif
@@ -813,6 +819,26 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
             ComputeEffectiveKernel(ph, BlockSize * KernelWidth,
                                    &filter, &ih, &EffectiveKernelHeight);
 
+            MLAS_CONV_FLOAT_KERNEL* KernelToUse = Kernel;
+#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)
+            if (WorkBlock->UseBf16 &&
+                EffectiveKernelHeight == 3 &&
+                KernelWidth == 3) {
+                //
+                // The current direct BF16 asm uses a two-output load pattern
+                // that reads one float past the end of the third source row.
+                // That is valid for interior rows because the next row is
+                // contiguous in memory, but it would step into the guard page
+                // on the final source row of the image.
+                //
+                const bool HasTrailingSourceRow =
+                    (ih + (EffectiveKernelHeight - 1) * DilationHeight + 1) < InputHeight;
+                if (!HasTrailingSourceRow) {
+                    KernelToUse = KernelFloat;
+                }
+            }
+#endif
+
             //
             // Apply the convolution kernel to each channel of the input tensor.
             //
@@ -828,7 +854,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
                 // Invoke the convolution kernel.
                 //
 
-                Kernel(input + (ih * InputWidth - PaddingLeftX), filter, output,
+                KernelToUse(input + (ih * InputWidth - PaddingLeftX), filter, output,
                     StrideWidthBytes, DilationWidthBytes, FilterCount, InputStrideBytes,
                     FilterStrideBytes, OutputStrideBytes, EffectiveKernelHeight, KernelWidth,
                     input + (ih * InputWidth), InputWidthBytes,
@@ -895,7 +921,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
     // output positions at once and significantly reduces memory traffic.
     MLAS_CONV_POINTWISE_FLOAT_KERNEL* const KernelFast = MlasConvPointwiseFloatKernelNeonAsm;
 #endif
-#if defined(__aarch64__) && defined(__linux__)
+#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)
     if (WorkBlock->UseBf16) {
         Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel;
     }
@@ -1048,6 +1074,11 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM
 #if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
     MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* const KernelFast = MlasConvDepthwiseFloatKernelNeonAsm;
 #endif
+#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)
+    if (WorkBlock->UseBf16) {
+        Kernel = GetMlasPlatform().ConvDepthwiseBf16Kernel;
+    }
+#endif
 #else
     MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel;
 #endif
@@ -1073,7 +1104,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM
 
         MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* KernelToUse = Kernel;
 #if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32)
-        if (OutputWidth >= 4) {
+        if (!WorkBlock->UseBf16 && OutputWidth >= 4) {
             KernelToUse = KernelFast;
         }
 #endif
diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h
index 04d32ab210c7b..d49a67bee3b18 100644
--- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h
+++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h
@@ -251,6 +251,21 @@ class MlasNchwcConv2DBf16Test : public MlasNchwcConv2DTest {
       }
     }
   }
+
+  // BF16 direct NCHW tests (InputChannels < BlockSize, non-depthwise).
+  // These shapes are eligible for the direct 3x3 BF16 asm kernel and route
+  // through ConvNchwBf16Kernel in snchwc.cpp.
+  TestBf16(1, 1, 3, 8, 8, 4, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
+  TestBf16(2, 1, 5, 9, 9, 3, 3, 3, 1, 1, 1, 1, 1, 1, 2, 2);
+  TestBf16(1, 1, 7, 11, 11, 2, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
+
+  // BF16 depthwise tests (grouped conv with one channel/filter per group).
+  // The padded case exercises the depthwise wrapper's left/right fallback
+  // around the BF16 asm interior span.
+  TestBf16(1, 16, 1, 8, 8, 1, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
+  TestBf16(1, 32, 1, 10, 10, 1, 3, 3, 0, 0, 0, 0, 1, 1, 2, 2);
+  TestBf16(1, 16, 1, 10, 10, 1, 3, 3, 0, 0, 0, 0, 1, 1, 3, 3);
+  TestBf16(1, 16, 1, 8, 8, 1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1);
 }
 
 private: