From 22dbed4571b2eac59448692c438a120a99cecf79 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 20 Mar 2026 21:05:35 +0000 Subject: [PATCH 1/6] mlas/arm64: Add AArch64 assembly path for NCHWc float kernel and wire into build Signed-off-by: Milos Puzovic --- cmake/onnxruntime_mlas.cmake | 1 + .../mlas/lib/aarch64/SconvNchwcKernelNeon.S | 1291 +++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 4 + .../core/mlas/lib/sconv_nchwc_kernel_neon.cpp | 24 + 4 files changed, 1320 insertions(+) create mode 100644 onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 0156e46b86bc4..188fdab05dafd 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -336,6 +336,7 @@ function (setup_arm_neon_nchwc) # separate assembly file allows tighter control over register allocation # and avoids the overhead of C++/intrinsics based code generation. ${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S ) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S new file mode 100644 index 0000000000000..0ee4ee6ab1bbc --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S @@ -0,0 +1,1291 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + + SconvNchwcKernelNeon.S + +Abstract: + + This module implements the single precision NCHWC convolution kernel for + AArch64 processors with NEON support. + +--*/ + +#if defined(__aarch64__) && !defined(_WIN32) + + .text + .align 2 + +// +// PROCESS_LANE - Load one 16-float filter row and accumulate it using a +// single input lane. +// + + .macro PROCESS_LANE offset0, offset1, inreg, lane + ldp q16, q17, [x7, #\offset0] + ldp q18, q19, [x7, #\offset1] + fmla v0.4s, v16.4s, \inreg\().s[\lane] + fmla v1.4s, v17.4s, \inreg\().s[\lane] + fmla v2.4s, v18.4s, \inreg\().s[\lane] + fmla v3.4s, v19.4s, \inreg\().s[\lane] + .endm + +// +// PROCESS_LANE_2OUT - Load one 16-float filter row and accumulate it using +// two input lanes for a dual-output path. +// + + .macro PROCESS_LANE_2OUT offset0, offset1, inreg0, inreg1, lane + ldp q16, q17, [x7, #\offset0] + ldp q18, q19, [x7, #\offset1] + fmla v0.4s, v16.4s, \inreg0\().s[\lane] + fmla v1.4s, v17.4s, \inreg0\().s[\lane] + fmla v2.4s, v18.4s, \inreg0\().s[\lane] + fmla v3.4s, v19.4s, \inreg0\().s[\lane] + fmla v20.4s, v16.4s, \inreg1\().s[\lane] + fmla v21.4s, v17.4s, \inreg1\().s[\lane] + fmla v22.4s, v18.4s, \inreg1\().s[\lane] + fmla v23.4s, v19.4s, \inreg1\().s[\lane] + .endm + +// +// PROCESS_LANE_3OUT - Load one 16-float filter row and accumulate it using +// three input lanes for a tri-output path. +// + + .macro PROCESS_LANE_3OUT offset0, offset1, inreg0, inreg1, inreg2, lane + ldp q16, q17, [x7, #\offset0] + ldp q18, q19, [x7, #\offset1] + fmla v0.4s, v16.4s, \inreg0\().s[\lane] + fmla v1.4s, v17.4s, \inreg0\().s[\lane] + fmla v2.4s, v18.4s, \inreg0\().s[\lane] + fmla v3.4s, v19.4s, \inreg0\().s[\lane] + fmla v4.4s, v16.4s, \inreg1\().s[\lane] + fmla v5.4s, v17.4s, \inreg1\().s[\lane] + fmla v6.4s, v18.4s, \inreg1\().s[\lane] + fmla v7.4s, v19.4s, \inreg1\().s[\lane] + fmla v8.4s, v16.4s, \inreg2\().s[\lane] + fmla v9.4s, v17.4s, \inreg2\().s[\lane] + fmla v10.4s, v18.4s, \inreg2\().s[\lane] + fmla v11.4s, v19.4s, \inreg2\().s[\lane] + .endm + +// +// PROCESS_LANE_4OUT - Load one 16-float filter row and accumulate it using +// four input lanes for a quad-output path. +// + + .macro PROCESS_LANE_4OUT offset0, offset1, inreg0, inreg1, inreg2, inreg3, lane + ldp q16, q17, [x7, #\offset0] + ldp q18, q19, [x7, #\offset1] + fmla v0.4s, v16.4s, \inreg0\().s[\lane] + fmla v1.4s, v17.4s, \inreg0\().s[\lane] + fmla v2.4s, v18.4s, \inreg0\().s[\lane] + fmla v3.4s, v19.4s, \inreg0\().s[\lane] + fmla v4.4s, v16.4s, \inreg1\().s[\lane] + fmla v5.4s, v17.4s, \inreg1\().s[\lane] + fmla v6.4s, v18.4s, \inreg1\().s[\lane] + fmla v7.4s, v19.4s, \inreg1\().s[\lane] + fmla v8.4s, v16.4s, \inreg2\().s[\lane] + fmla v9.4s, v17.4s, \inreg2\().s[\lane] + fmla v10.4s, v18.4s, \inreg2\().s[\lane] + fmla v11.4s, v19.4s, \inreg2\().s[\lane] + fmla v12.4s, v16.4s, \inreg3\().s[\lane] + fmla v13.4s, v17.4s, \inreg3\().s[\lane] + fmla v14.4s, v18.4s, \inreg3\().s[\lane] + fmla v15.4s, v19.4s, \inreg3\().s[\lane] + .endm + +// +// CONV_LOOP_PAD - Convolution loop with per-position bounds checks. This +// path is used for the padded output regions on the left and right edges. +// + + .macro CONV_LOOP_PAD + mov x7, x4 // Filter pointer + mov x8, x1 // Input row pointer + mov x9, x13 // Row start pointer + mov x10, x27 // KernelHeight counter + + cbz x10, .Lconv_done_pad_\@ + cbz x28, .Lconv_done_pad_\@ + + .Lkh_loop_pad_\@: + mov x12, x8 // Input pointer for width + mov x16, x28 // KernelWidth counter + + .Lkw_loop_pad_\@: + // Branch if the input pointer lies outside [row_start, row_start + row_width). + sub x11, x12, x9 + cmp x11, x14 + b.hs .Lkw_skip_pad_\@ + + ldp q4, q5, [x12, #0] + ldp q6, q7, [x12, #32] + + PROCESS_LANE 0, 32, v4, 0 + PROCESS_LANE 64, 96, v4, 1 + PROCESS_LANE 128, 160, v4, 2 + PROCESS_LANE 192, 224, v4, 3 + PROCESS_LANE 256, 288, v5, 0 + PROCESS_LANE 320, 352, v5, 1 + PROCESS_LANE 384, 416, v5, 2 + PROCESS_LANE 448, 480, v5, 3 + PROCESS_LANE 512, 544, v6, 0 + PROCESS_LANE 576, 608, v6, 1 + PROCESS_LANE 640, 672, v6, 2 + PROCESS_LANE 704, 736, v6, 3 + PROCESS_LANE 768, 800, v7, 0 + PROCESS_LANE 832, 864, v7, 1 + PROCESS_LANE 896, 928, v7, 2 + PROCESS_LANE 960, 992, v7, 3 + + .Lkw_skip_pad_\@: + add x7, x7, #1024 + add x12, x12, x23 + subs x16, x16, #1 + b.ne .Lkw_loop_pad_\@ + + add x9, x9, x15 + add x8, x8, x15 + subs x10, x10, #1 + b.ne .Lkh_loop_pad_\@ + + .Lconv_done_pad_\@: + .endm + +// +// CONV_LOOP_MID - Convolution loop without bounds checks. Output positions in +// the middle region are guaranteed to be fully in-bounds. +// + + .macro CONV_LOOP_MID + mov x7, x4 // Filter pointer + mov x8, x1 // Input row pointer + mov x10, x27 // KernelHeight counter + + cbz x10, .Lconv_done_mid_\@ + cbz x28, .Lconv_done_mid_\@ + + .Lkh_loop_mid_\@: + mov x12, x8 // Input pointer for width + mov x16, x28 // KernelWidth counter + + .Lkw_loop_mid_\@: + ldp q4, q5, [x12, #0] + ldp q6, q7, [x12, #32] + + PROCESS_LANE 0, 32, v4, 0 + PROCESS_LANE 64, 96, v4, 1 + PROCESS_LANE 128, 160, v4, 2 + PROCESS_LANE 192, 224, v4, 3 + PROCESS_LANE 256, 288, v5, 0 + PROCESS_LANE 320, 352, v5, 1 + PROCESS_LANE 384, 416, v5, 2 + PROCESS_LANE 448, 480, v5, 3 + PROCESS_LANE 512, 544, v6, 0 + PROCESS_LANE 576, 608, v6, 1 + PROCESS_LANE 640, 672, v6, 2 + PROCESS_LANE 704, 736, v6, 3 + PROCESS_LANE 768, 800, v7, 0 + PROCESS_LANE 832, 864, v7, 1 + PROCESS_LANE 896, 928, v7, 2 + PROCESS_LANE 960, 992, v7, 3 + + add x7, x7, #1024 + add x12, x12, x23 + subs x16, x16, #1 + b.ne .Lkw_loop_mid_\@ + + add x8, x8, x15 + subs x10, x10, #1 + b.ne .Lkh_loop_mid_\@ + + .Lconv_done_mid_\@: + .endm + +// +// CONV_LOOP_MID_3OUT - Convolution loop without bounds checks that computes +// three adjacent output points per iteration. +// + + .macro CONV_LOOP_MID_3OUT + mov x7, x4 // Filter pointer + mov x8, x1 // Input row pointer (output 0) + mov x9, x27 // KernelHeight counter + + cbz x9, .Lconv_done_mid3_\@ + cbz x28, .Lconv_done_mid3_\@ + + .Lkh_loop_mid3_\@: + mov x12, x8 // Input pointer for width (output 0) + mov x16, x28 // KernelWidth counter + + .Lkw_loop_mid3_\@: + add x11, x12, x22 // Output 1 input pointer + add x10, x12, x22, lsl #1 // Output 2 input pointer + + ldp q20, q21, [x12, #0] + ldp q22, q23, [x12, #32] + ldp q24, q25, [x11, #0] + ldp q26, q27, [x11, #32] + ldp q12, q13, [x10, #0] + ldp q14, q15, [x10, #32] + + PROCESS_LANE_3OUT 0, 32, v20, v24, v12, 0 + PROCESS_LANE_3OUT 64, 96, v20, v24, v12, 1 + PROCESS_LANE_3OUT 128, 160, v20, v24, v12, 2 + PROCESS_LANE_3OUT 192, 224, v20, v24, v12, 3 + PROCESS_LANE_3OUT 256, 288, v21, v25, v13, 0 + PROCESS_LANE_3OUT 320, 352, v21, v25, v13, 1 + PROCESS_LANE_3OUT 384, 416, v21, v25, v13, 2 + PROCESS_LANE_3OUT 448, 480, v21, v25, v13, 3 + PROCESS_LANE_3OUT 512, 544, v22, v26, v14, 0 + PROCESS_LANE_3OUT 576, 608, v22, v26, v14, 1 + PROCESS_LANE_3OUT 640, 672, v22, v26, v14, 2 + PROCESS_LANE_3OUT 704, 736, v22, v26, v14, 3 + PROCESS_LANE_3OUT 768, 800, v23, v27, v15, 0 + PROCESS_LANE_3OUT 832, 864, v23, v27, v15, 1 + PROCESS_LANE_3OUT 896, 928, v23, v27, v15, 2 + PROCESS_LANE_3OUT 960, 992, v23, v27, v15, 3 + + add x7, x7, #1024 + add x12, x12, x23 + subs x16, x16, #1 + b.ne .Lkw_loop_mid3_\@ + + add x8, x8, x15 + subs x9, x9, #1 + b.ne .Lkh_loop_mid3_\@ + + .Lconv_done_mid3_\@: + .endm + +// +// CONV_LOOP_MID_2OUT - Convolution loop without bounds checks that computes +// two adjacent output points per iteration. +// + + .macro CONV_LOOP_MID_2OUT + mov x7, x4 // Filter pointer + mov x8, x1 // Input row pointer (output 0) + add x9, x1, x22 // Input row pointer (output 1) + mov x10, x27 // KernelHeight counter + + cbz x10, .Lconv_done_mid2_\@ + cbz x28, .Lconv_done_mid2_\@ + + .Lkh_loop_mid2_\@: + mov x12, x8 // Input pointer for width (output 0) + mov x11, x9 // Input pointer for width (output 1) + mov x16, x28 // KernelWidth counter + + .Lkw_loop_mid2_\@: + ldp q4, q5, [x12, #0] + ldp q6, q7, [x12, #32] + ldp q24, q25, [x11, #0] + ldp q26, q27, [x11, #32] + + PROCESS_LANE_2OUT 0, 32, v4, v24, 0 + PROCESS_LANE_2OUT 64, 96, v4, v24, 1 + PROCESS_LANE_2OUT 128, 160, v4, v24, 2 + PROCESS_LANE_2OUT 192, 224, v4, v24, 3 + PROCESS_LANE_2OUT 256, 288, v5, v25, 0 + PROCESS_LANE_2OUT 320, 352, v5, v25, 1 + PROCESS_LANE_2OUT 384, 416, v5, v25, 2 + PROCESS_LANE_2OUT 448, 480, v5, v25, 3 + PROCESS_LANE_2OUT 512, 544, v6, v26, 0 + PROCESS_LANE_2OUT 576, 608, v6, v26, 1 + PROCESS_LANE_2OUT 640, 672, v6, v26, 2 + PROCESS_LANE_2OUT 704, 736, v6, v26, 3 + PROCESS_LANE_2OUT 768, 800, v7, v27, 0 + PROCESS_LANE_2OUT 832, 864, v7, v27, 1 + PROCESS_LANE_2OUT 896, 928, v7, v27, 2 + PROCESS_LANE_2OUT 960, 992, v7, v27, 3 + + add x7, x7, #1024 + add x12, x12, x23 + add x11, x11, x23 + subs x16, x16, #1 + b.ne .Lkw_loop_mid2_\@ + + add x8, x8, x15 + add x9, x9, x15 + subs x10, x10, #1 + b.ne .Lkh_loop_mid2_\@ + + .Lconv_done_mid2_\@: + .endm + +// +// CONV_LOOP_MID_4OUT - Convolution loop without bounds checks that computes +// four adjacent output points per iteration. This path is used by the +// KernelFlags==0 fast path to reduce filter load pressure. +// + + .macro CONV_LOOP_MID_4OUT + mov x7, x4 // Filter pointer + mov x8, x1 // Input row pointer (output 0) + mov x6, x27 // KernelHeight counter (x6 is scratch in this path) + + cbz x6, .Lconv_done_mid4_\@ + cbz x28, .Lconv_done_mid4_\@ + + .Lkh_loop_mid4_\@: + mov x12, x8 // Input pointer for width (output 0) + mov x16, x28 // KernelWidth counter + + .Lkw_loop_mid4_\@: + add x9, x12, x22 // Output 1 input pointer + add x10, x9, x22 // Output 2 input pointer + add x11, x10, x22 // Output 3 input pointer + + // Load lanes 0..7 for all outputs. + ldp q20, q21, [x12, #0] + ldp q22, q23, [x9, #0] + ldp q24, q25, [x10, #0] + ldp q26, q27, [x11, #0] + + PROCESS_LANE_4OUT 0, 32, v20, v22, v24, v26, 0 + PROCESS_LANE_4OUT 64, 96, v20, v22, v24, v26, 1 + PROCESS_LANE_4OUT 128, 160, v20, v22, v24, v26, 2 + PROCESS_LANE_4OUT 192, 224, v20, v22, v24, v26, 3 + PROCESS_LANE_4OUT 256, 288, v21, v23, v25, v27, 0 + PROCESS_LANE_4OUT 320, 352, v21, v23, v25, v27, 1 + PROCESS_LANE_4OUT 384, 416, v21, v23, v25, v27, 2 + PROCESS_LANE_4OUT 448, 480, v21, v23, v25, v27, 3 + + // Load lanes 8..15 for all outputs. + ldp q20, q21, [x12, #32] + ldp q22, q23, [x9, #32] + ldp q24, q25, [x10, #32] + ldp q26, q27, [x11, #32] + + PROCESS_LANE_4OUT 512, 544, v20, v22, v24, v26, 0 + PROCESS_LANE_4OUT 576, 608, v20, v22, v24, v26, 1 + PROCESS_LANE_4OUT 640, 672, v20, v22, v24, v26, 2 + PROCESS_LANE_4OUT 704, 736, v20, v22, v24, v26, 3 + PROCESS_LANE_4OUT 768, 800, v21, v23, v25, v27, 0 + PROCESS_LANE_4OUT 832, 864, v21, v23, v25, v27, 1 + PROCESS_LANE_4OUT 896, 928, v21, v23, v25, v27, 2 + PROCESS_LANE_4OUT 960, 992, v21, v23, v25, v27, 3 + + add x7, x7, #1024 + add x12, x12, x23 + subs x16, x16, #1 + b.ne .Lkw_loop_mid4_\@ + + add x8, x8, x15 + subs x6, x6, #1 + b.ne .Lkh_loop_mid4_\@ + + .Lconv_done_mid4_\@: + .endm + +// +// void +// MlasConvNchwcFloatKernelNeonAsm( +// const float* Input, +// const float* Filter, +// float* Output, +// size_t StrideWidth, +// size_t DilationWidth, +// size_t FilterCount, +// size_t InputStride, +// size_t FilterStride, +// size_t OutputStride, +// size_t KernelHeight, +// size_t KernelWidth, +// const float* InputBase, +// size_t InputWidth, +// size_t DilatedInputWidth, +// size_t OutputCountLeftPad, +// size_t OutputCount, +// size_t OutputCountRightPad, +// const float* Bias, +// unsigned KernelFlags +// ); +// + + .global MlasConvNchwcFloatKernelNeonAsm + .type MlasConvNchwcFloatKernelNeonAsm, %function + +MlasConvNchwcFloatKernelNeonAsm: + + // Preserve the incoming stack pointer to access stack-passed arguments. + mov x9, sp + + // Prologue and callee-saved register spill. + stp x29, x30, [sp, #-128]! + mov x29, sp + stp x19, x20, [sp, #16] + stp x21, x22, [sp, #32] + stp x23, x24, [sp, #48] + stp x25, x26, [sp, #64] + stp x27, x28, [sp, #80] + + // Move register arguments into callee-saved registers. + mov x19, x0 // Input + mov x20, x1 // Filter + mov x21, x2 // Output + mov x22, x3 // StrideWidth (bytes) + mov x23, x4 // DilationWidth (bytes) + mov x24, x5 // FilterCount + mov x25, x7 // FilterStride (bytes) + + // Load stack arguments using the preserved incoming stack pointer. + ldr x10, [x9, #0] // OutputStride (bytes) + ldr x11, [x9, #8] // KernelHeight + ldr x12, [x9, #16] // KernelWidth + ldr x13, [x9, #24] // InputBase + ldr x14, [x9, #32] // InputWidth (bytes) + ldr x15, [x9, #40] // DilatedInputWidth (bytes) + ldr x16, [x9, #48] // OutputCountLeftPad + ldr x17, [x9, #56] // OutputCount + ldr x18, [x9, #64] // OutputCountRightPad + ldr x6, [x9, #72] // Bias + ldr w8, [x9, #80] // KernelFlags + + // Early exit when nothing to compute. + add x0, x16, x17 + add x0, x0, x18 // x0 = TotalOutputCount + cbz x0, .Lepilogue + cbz x24, .Lepilogue + + // Spill the output counts so that x16/x17/x18 can be used as scratch. + str x16, [sp, #96] + str x17, [sp, #104] + str x18, [sp, #112] + + mov x26, x10 // OutputStride (bytes) + mov x27, x11 // KernelHeight + mov x28, x12 // KernelWidth + mov x17, x6 // Bias + mov w18, w8 // KernelFlags + + // Set up a zero vector for ReLU. + movi v31.4s, #0 + + // x1 = current input base for the output index. + // x2 = output offset in bytes for the output index. + mov x1, x19 + mov x2, xzr + + // Fast path when no post-processing flags are enabled. This removes + // repeated flag checks and branches from the steady-state loops. + tst w18, #7 + b.eq .Lkernel_flags0 + + // Process the left padded output region with bounds checks. + ldr x0, [sp, #96] + cbz x0, .Loutput_mid_begin + +.Loutput_left_loop: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + mov x6, x17 + +.Lfilter_left_loop: + + // Clear accumulators. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + + // Convolution loop with bounds checks. + CONV_LOOP_PAD + + // Compute the output pointer for this filter set and output index. + add x12, x5, x2 + + // Conditionally accumulate the existing output. + tst w18, #1 + b.eq .Lskip_accumulate + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + +.Lskip_accumulate: + + // Conditionally add bias. + tst w18, #2 + b.eq .Lskip_bias + ldp q16, q17, [x6, #0] + ldp q18, q19, [x6, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + +.Lskip_bias: + + // Conditionally apply ReLU. + tst w18, #4 + b.eq .Lskip_relu + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + +.Lskip_relu: + + // Store the result. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + + // Advance filter/output/bias pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + add x6, x6, #64 + subs x3, x3, #1 + b.ne .Lfilter_left_loop + + // Advance to the next output index. + add x1, x1, x22 + add x2, x2, #64 + subs x0, x0, #1 + b.ne .Loutput_left_loop + + // Process the middle output region without bounds checks. +.Loutput_mid_begin: + ldr x0, [sp, #104] + cbz x0, .Loutput_right_begin + + // Compute the number of output triads and spill the remaining outputs. + mov x12, x0 + // Use a reciprocal multiply to divide by 3 and avoid the expensive UDIV. + movz x16, #0xaaab + movk x16, #0xaaaa, lsl #16 + movk x16, #0xaaaa, lsl #32 + movk x16, #0xaaaa, lsl #48 + umulh x0, x0, x16 + lsr x0, x0, #1 + // Compute remainder = OutputCount - triads * 3. + add x16, x0, x0, lsl #1 + sub x16, x12, x16 + str x16, [sp, #120] + cbz x0, .Loutput_mid_pair_begin + +.Loutput_mid_triad_loop: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + mov x6, x17 + +.Lfilter_mid_triad_loop: + + // Clear accumulators for three output points. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + eor v4.16b, v4.16b, v4.16b + eor v5.16b, v5.16b, v5.16b + eor v6.16b, v6.16b, v6.16b + eor v7.16b, v7.16b, v7.16b + eor v8.16b, v8.16b, v8.16b + eor v9.16b, v9.16b, v9.16b + eor v10.16b, v10.16b, v10.16b + eor v11.16b, v11.16b, v11.16b + + // Convolution loop without bounds checks computing three outputs. + CONV_LOOP_MID_3OUT + + // Compute the output pointers for the three output points. + add x12, x5, x2 + add x11, x12, #64 + add x10, x12, #128 + + // Conditionally accumulate the existing output. + tst w18, #1 + b.eq .Lskip_accumulate_triad + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + ldp q20, q21, [x11, #0] + ldp q22, q23, [x11, #32] + ldp q24, q25, [x10, #0] + ldp q26, q27, [x10, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + fadd v4.4s, v4.4s, v20.4s + fadd v5.4s, v5.4s, v21.4s + fadd v6.4s, v6.4s, v22.4s + fadd v7.4s, v7.4s, v23.4s + fadd v8.4s, v8.4s, v24.4s + fadd v9.4s, v9.4s, v25.4s + fadd v10.4s, v10.4s, v26.4s + fadd v11.4s, v11.4s, v27.4s + +.Lskip_accumulate_triad: + + // Conditionally add bias. + tst w18, #2 + b.eq .Lskip_bias_triad + ldp q16, q17, [x6, #0] + ldp q18, q19, [x6, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v18.4s + fadd v7.4s, v7.4s, v19.4s + fadd v8.4s, v8.4s, v16.4s + fadd v9.4s, v9.4s, v17.4s + fadd v10.4s, v10.4s, v18.4s + fadd v11.4s, v11.4s, v19.4s + +.Lskip_bias_triad: + + // Conditionally apply ReLU. + tst w18, #4 + b.eq .Lskip_relu_triad + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + fmax v4.4s, v4.4s, v31.4s + fmax v5.4s, v5.4s, v31.4s + fmax v6.4s, v6.4s, v31.4s + fmax v7.4s, v7.4s, v31.4s + fmax v8.4s, v8.4s, v31.4s + fmax v9.4s, v9.4s, v31.4s + fmax v10.4s, v10.4s, v31.4s + fmax v11.4s, v11.4s, v31.4s + +.Lskip_relu_triad: + + // Store the results for the three output points. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + stp q4, q5, [x11, #0] + stp q6, q7, [x11, #32] + stp q8, q9, [x10, #0] + stp q10, q11, [x10, #32] + + // Advance filter/output/bias pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + add x6, x6, #64 + subs x3, x3, #1 + b.ne .Lfilter_mid_triad_loop + + // Advance to the next three output indices. + add x1, x1, x22, lsl #1 + add x1, x1, x22 + add x2, x2, #192 + subs x0, x0, #1 + b.ne .Loutput_mid_triad_loop + +.Loutput_mid_pair_begin: + ldr x0, [sp, #120] + and x16, x0, #1 + str x16, [sp, #120] + lsr x0, x0, #1 + cbz x0, .Loutput_mid_single_begin + +.Loutput_mid_pair_loop: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + mov x6, x17 + +.Lfilter_mid_pair_loop: + + // Clear accumulators for both output points. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + + // Convolution loop without bounds checks computing two outputs. + CONV_LOOP_MID_2OUT + + // Compute the output pointers for the two output points. + add x12, x5, x2 + add x11, x12, #64 + + // Conditionally accumulate the existing output. + tst w18, #1 + b.eq .Lskip_accumulate_pair + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + ldp q24, q25, [x11, #0] + ldp q26, q27, [x11, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + fadd v20.4s, v20.4s, v24.4s + fadd v21.4s, v21.4s, v25.4s + fadd v22.4s, v22.4s, v26.4s + fadd v23.4s, v23.4s, v27.4s + +.Lskip_accumulate_pair: + + // Conditionally add bias. + tst w18, #2 + b.eq .Lskip_bias_pair + ldp q16, q17, [x6, #0] + ldp q18, q19, [x6, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + fadd v20.4s, v20.4s, v16.4s + fadd v21.4s, v21.4s, v17.4s + fadd v22.4s, v22.4s, v18.4s + fadd v23.4s, v23.4s, v19.4s + +.Lskip_bias_pair: + + // Conditionally apply ReLU. + tst w18, #4 + b.eq .Lskip_relu_pair + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + fmax v20.4s, v20.4s, v31.4s + fmax v21.4s, v21.4s, v31.4s + fmax v22.4s, v22.4s, v31.4s + fmax v23.4s, v23.4s, v31.4s + +.Lskip_relu_pair: + + // Store the results for the two output points. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + stp q20, q21, [x11, #0] + stp q22, q23, [x11, #32] + + // Advance filter/output/bias pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + add x6, x6, #64 + subs x3, x3, #1 + b.ne .Lfilter_mid_pair_loop + + // Advance to the next two output indices. + add x1, x1, x22, lsl #1 + add x2, x2, #128 + subs x0, x0, #1 + b.ne .Loutput_mid_pair_loop + +.Loutput_mid_single_begin: + ldr x0, [sp, #120] + cbz x0, .Loutput_right_begin + +.Loutput_mid_loop: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + mov x6, x17 + +.Lfilter_mid_loop: + + // Clear accumulators. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + + // Convolution loop without bounds checks. + CONV_LOOP_MID + + // Compute the output pointer for this filter set and output index. + add x12, x5, x2 + + // Conditionally accumulate the existing output. + tst w18, #1 + b.eq .Lskip_accumulate_mid + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + +.Lskip_accumulate_mid: + + // Conditionally add bias. + tst w18, #2 + b.eq .Lskip_bias_mid + ldp q16, q17, [x6, #0] + ldp q18, q19, [x6, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + +.Lskip_bias_mid: + + // Conditionally apply ReLU. + tst w18, #4 + b.eq .Lskip_relu_mid + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + +.Lskip_relu_mid: + + // Store the result. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + + // Advance filter/output/bias pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + add x6, x6, #64 + subs x3, x3, #1 + b.ne .Lfilter_mid_loop + + // Advance to the next output index. + add x1, x1, x22 + add x2, x2, #64 + subs x0, x0, #1 + b.ne .Loutput_mid_loop + + // Process the right padded output region with bounds checks. +.Loutput_right_begin: + ldr x0, [sp, #112] + cbz x0, .Lepilogue + +.Loutput_right_loop: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + mov x6, x17 + +.Lfilter_right_loop: + + // Clear accumulators. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + + // Convolution loop with bounds checks. + CONV_LOOP_PAD + + // Compute the output pointer for this filter set and output index. + add x12, x5, x2 + + // Conditionally accumulate the existing output. + tst w18, #1 + b.eq .Lskip_accumulate_right + ldp q16, q17, [x12, #0] + ldp q18, q19, [x12, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + +.Lskip_accumulate_right: + + // Conditionally add bias. + tst w18, #2 + b.eq .Lskip_bias_right + ldp q16, q17, [x6, #0] + ldp q18, q19, [x6, #32] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s + +.Lskip_bias_right: + + // Conditionally apply ReLU. + tst w18, #4 + b.eq .Lskip_relu_right + fmax v0.4s, v0.4s, v31.4s + fmax v1.4s, v1.4s, v31.4s + fmax v2.4s, v2.4s, v31.4s + fmax v3.4s, v3.4s, v31.4s + +.Lskip_relu_right: + + // Store the result. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + + // Advance filter/output/bias pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + add x6, x6, #64 + subs x3, x3, #1 + b.ne .Lfilter_right_loop + + // Advance to the next output index. + add x1, x1, x22 + add x2, x2, #64 + subs x0, x0, #1 + b.ne .Loutput_right_loop + + // Skip the flag-free fast path section. + b .Lepilogue + +.Lkernel_flags0: + + // KernelFlags == 0 fast path: no accumulation, bias, or activation. + + // Process the left padded output region with bounds checks. + ldr x0, [sp, #96] + cbz x0, .Loutput_mid_begin_flags0 + +.Loutput_left_loop_flags0: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + +.Lfilter_left_loop_flags0: + + // Clear accumulators. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + + // Convolution loop with bounds checks. + CONV_LOOP_PAD + + // Compute the output pointer for this filter set and output index. + add x12, x5, x2 + + // Store the result. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + + // Advance filter/output pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + subs x3, x3, #1 + b.ne .Lfilter_left_loop_flags0 + + // Advance to the next output index. + add x1, x1, x22 + add x2, x2, #64 + subs x0, x0, #1 + b.ne .Loutput_left_loop_flags0 + + // Process the middle output region without bounds checks. +.Loutput_mid_begin_flags0: + ldr x0, [sp, #104] + cbz x0, .Loutput_right_begin_flags0 + + // Process four outputs at a time to amortize the filter loads. + and x16, x0, #3 // Remainder outputs after quads. + str x16, [sp, #120] + lsr x0, x0, #2 // Quad output count. + cbz x0, .Loutput_mid_remainder_begin_flags0 + +.Loutput_mid_quad_loop_flags0: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + +.Lfilter_mid_quad_loop_flags0: + + // Clear accumulators for four output points. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + eor v4.16b, v4.16b, v4.16b + eor v5.16b, v5.16b, v5.16b + eor v6.16b, v6.16b, v6.16b + eor v7.16b, v7.16b, v7.16b + eor v8.16b, v8.16b, v8.16b + eor v9.16b, v9.16b, v9.16b + eor v10.16b, v10.16b, v10.16b + eor v11.16b, v11.16b, v11.16b + eor v12.16b, v12.16b, v12.16b + eor v13.16b, v13.16b, v13.16b + eor v14.16b, v14.16b, v14.16b + eor v15.16b, v15.16b, v15.16b + + // Convolution loop without bounds checks computing four outputs. + CONV_LOOP_MID_4OUT + + // Compute the output pointers for the four output points. + add x12, x5, x2 + add x11, x12, #64 + add x10, x12, #128 + add x9, x12, #192 + + // Store the results for the four output points. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + stp q4, q5, [x11, #0] + stp q6, q7, [x11, #32] + stp q8, q9, [x10, #0] + stp q10, q11, [x10, #32] + stp q12, q13, [x9, #0] + stp q14, q15, [x9, #32] + + // Advance filter/output pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + subs x3, x3, #1 + b.ne .Lfilter_mid_quad_loop_flags0 + + // Advance to the next four output indices. + add x1, x1, x22, lsl #2 + add x2, x2, #256 + subs x0, x0, #1 + b.ne .Loutput_mid_quad_loop_flags0 + +.Loutput_mid_remainder_begin_flags0: + ldr x0, [sp, #120] + cbz x0, .Loutput_right_begin_flags0 + cmp x0, #3 + b.ne .Loutput_mid_remainder_not3_flags0 + // Exactly three outputs remain. + str xzr, [sp, #120] + mov x0, #1 + b .Loutput_mid_triad_loop_flags0 + +.Loutput_mid_remainder_not3_flags0: + cmp x0, #2 + b.ne .Loutput_mid_remainder_single_flags0 + // Exactly two outputs remain. + str xzr, [sp, #120] + mov x0, #1 + b .Loutput_mid_pair_loop_flags0 + +.Loutput_mid_remainder_single_flags0: + // Exactly one output remains. + mov x0, #1 + b .Loutput_mid_loop_flags0 + +.Loutput_mid_triad_loop_flags0: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + +.Lfilter_mid_triad_loop_flags0: + + // Clear accumulators for three output points. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + eor v4.16b, v4.16b, v4.16b + eor v5.16b, v5.16b, v5.16b + eor v6.16b, v6.16b, v6.16b + eor v7.16b, v7.16b, v7.16b + eor v8.16b, v8.16b, v8.16b + eor v9.16b, v9.16b, v9.16b + eor v10.16b, v10.16b, v10.16b + eor v11.16b, v11.16b, v11.16b + + // Convolution loop without bounds checks computing three outputs. + CONV_LOOP_MID_3OUT + + // Compute the output pointers for the three output points. + add x12, x5, x2 + add x11, x12, #64 + add x10, x12, #128 + + // Store the results for the three output points. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + stp q4, q5, [x11, #0] + stp q6, q7, [x11, #32] + stp q8, q9, [x10, #0] + stp q10, q11, [x10, #32] + + // Advance filter/output pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + subs x3, x3, #1 + b.ne .Lfilter_mid_triad_loop_flags0 + + // Advance to the next three output indices. + add x1, x1, x22, lsl #1 + add x1, x1, x22 + add x2, x2, #192 + subs x0, x0, #1 + b.ne .Loutput_mid_triad_loop_flags0 + +.Loutput_mid_pair_begin_flags0: + ldr x0, [sp, #120] + and x16, x0, #1 + str x16, [sp, #120] + lsr x0, x0, #1 + cbz x0, .Loutput_mid_single_begin_flags0 + +.Loutput_mid_pair_loop_flags0: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + +.Lfilter_mid_pair_loop_flags0: + + // Clear accumulators for both output points. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + + // Convolution loop without bounds checks computing two outputs. + CONV_LOOP_MID_2OUT + + // Compute the output pointers for the two output points. + add x12, x5, x2 + add x11, x12, #64 + + // Store the results for the two output points. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + stp q20, q21, [x11, #0] + stp q22, q23, [x11, #32] + + // Advance filter/output pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + subs x3, x3, #1 + b.ne .Lfilter_mid_pair_loop_flags0 + + // Advance to the next two output indices. + add x1, x1, x22, lsl #1 + add x2, x2, #128 + subs x0, x0, #1 + b.ne .Loutput_mid_pair_loop_flags0 + +.Loutput_mid_single_begin_flags0: + ldr x0, [sp, #120] + cbz x0, .Loutput_right_begin_flags0 + +.Loutput_mid_loop_flags0: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + +.Lfilter_mid_loop_flags0: + + // Clear accumulators. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + + // Convolution loop without bounds checks. + CONV_LOOP_MID + + // Compute the output pointer for this filter set and output index. + add x12, x5, x2 + + // Store the result. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + + // Advance filter/output pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + subs x3, x3, #1 + b.ne .Lfilter_mid_loop_flags0 + + // Advance to the next output index. + add x1, x1, x22 + add x2, x2, #64 + subs x0, x0, #1 + b.ne .Loutput_mid_loop_flags0 + + // Process the right padded output region with bounds checks. +.Loutput_right_begin_flags0: + ldr x0, [sp, #112] + cbz x0, .Lepilogue + +.Loutput_right_loop_flags0: + + // Initialize per-filter-set pointers and loop counter. + mov x3, x24 + mov x4, x20 + mov x5, x21 + +.Lfilter_right_loop_flags0: + + // Clear accumulators. + eor v0.16b, v0.16b, v0.16b + eor v1.16b, v1.16b, v1.16b + eor v2.16b, v2.16b, v2.16b + eor v3.16b, v3.16b, v3.16b + + // Convolution loop with bounds checks. + CONV_LOOP_PAD + + // Compute the output pointer for this filter set and output index. + add x12, x5, x2 + + // Store the result. + stp q0, q1, [x12, #0] + stp q2, q3, [x12, #32] + + // Advance filter/output pointers for the next filter set block. + add x4, x4, x25 + add x5, x5, x26 + subs x3, x3, #1 + b.ne .Lfilter_right_loop_flags0 + + // Advance to the next output index. + add x1, x1, x22 + add x2, x2, #64 + subs x0, x0, #1 + b.ne .Loutput_right_loop_flags0 + + b .Lepilogue + +.Lepilogue: + + // Epilogue and callee-saved register restore. + ldp x27, x28, [sp, #80] + ldp x25, x26, [sp, #64] + ldp x23, x24, [sp, #48] + ldp x21, x22, [sp, #32] + ldp x19, x20, [sp, #16] + ldp x29, x30, [sp], #128 + ret + + .size MlasConvNchwcFloatKernelNeonAsm, .-MlasConvNchwcFloatKernelNeonAsm + +#endif // defined(__aarch64__) && !defined(_WIN32) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 954849fe90049..f59ec4858d0d7 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1008,6 +1008,10 @@ extern "C" { MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeonAsm; #endif MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; +#if !defined(_WIN32) + // AArch64 assembly micro-kernel for direct NCHWc convolution + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeonAsm; +#endif // Intrinsics kernel for depthwise NCHWc convolution MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; #if !defined(_WIN32) diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp index 59c30a80b53af..6ca2398f91503 100644 --- a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp @@ -263,6 +263,29 @@ void unsigned KernelFlags ) { +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32) + MlasConvNchwcFloatKernelNeonAsm( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +#else MlasConvFloatKernelNeonImpl( Input, Filter, @@ -284,6 +307,7 @@ void Bias, KernelFlags ); +#endif } // From 5e0dc2dbb6c1f8574f4315f7accbb63d2558db31 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 20 Mar 2026 22:37:59 +0000 Subject: [PATCH 2/6] Mach-O / Apple portability and AArch64 ABI correctness fix Signed-off-by: Milos Puzovic --- .../mlas/lib/aarch64/SconvNchwcKernelNeon.S | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S index 0ee4ee6ab1bbc..18fdab92b4db9 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S @@ -15,6 +15,8 @@ Abstract: #if defined(__aarch64__) && !defined(_WIN32) +#include "asmmacro.h" + .text .align 2 @@ -409,22 +411,24 @@ Abstract: // ); // - .global MlasConvNchwcFloatKernelNeonAsm - .type MlasConvNchwcFloatKernelNeonAsm, %function - -MlasConvNchwcFloatKernelNeonAsm: + FUNCTION_ENTRY MlasConvNchwcFloatKernelNeonAsm // Preserve the incoming stack pointer to access stack-passed arguments. mov x9, sp // Prologue and callee-saved register spill. - stp x29, x30, [sp, #-128]! + // Save callee-saved SIMD registers v8-v15 per AArch64 ABI. + stp x29, x30, [sp, #-256]! mov x29, sp stp x19, x20, [sp, #16] stp x21, x22, [sp, #32] stp x23, x24, [sp, #48] stp x25, x26, [sp, #64] stp x27, x28, [sp, #80] + stp q8, q9, [sp, #96] + stp q10, q11, [sp, #128] + stp q12, q13, [sp, #160] + stp q14, q15, [sp, #192] // Move register arguments into callee-saved registers. mov x19, x0 // Input @@ -455,9 +459,9 @@ MlasConvNchwcFloatKernelNeonAsm: cbz x24, .Lepilogue // Spill the output counts so that x16/x17/x18 can be used as scratch. - str x16, [sp, #96] - str x17, [sp, #104] - str x18, [sp, #112] + str x16, [sp, #224] + str x17, [sp, #232] + str x18, [sp, #240] mov x26, x10 // OutputStride (bytes) mov x27, x11 // KernelHeight @@ -479,7 +483,7 @@ MlasConvNchwcFloatKernelNeonAsm: b.eq .Lkernel_flags0 // Process the left padded output region with bounds checks. - ldr x0, [sp, #96] + ldr x0, [sp, #224] cbz x0, .Loutput_mid_begin .Loutput_left_loop: @@ -557,7 +561,7 @@ MlasConvNchwcFloatKernelNeonAsm: // Process the middle output region without bounds checks. .Loutput_mid_begin: - ldr x0, [sp, #104] + ldr x0, [sp, #232] cbz x0, .Loutput_right_begin // Compute the number of output triads and spill the remaining outputs. @@ -572,7 +576,7 @@ MlasConvNchwcFloatKernelNeonAsm: // Compute remainder = OutputCount - triads * 3. add x16, x0, x0, lsl #1 sub x16, x12, x16 - str x16, [sp, #120] + str x16, [sp, #248] cbz x0, .Loutput_mid_pair_begin .Loutput_mid_triad_loop: @@ -692,9 +696,9 @@ MlasConvNchwcFloatKernelNeonAsm: b.ne .Loutput_mid_triad_loop .Loutput_mid_pair_begin: - ldr x0, [sp, #120] + ldr x0, [sp, #248] and x16, x0, #1 - str x16, [sp, #120] + str x16, [sp, #248] lsr x0, x0, #1 cbz x0, .Loutput_mid_single_begin @@ -793,7 +797,7 @@ MlasConvNchwcFloatKernelNeonAsm: b.ne .Loutput_mid_pair_loop .Loutput_mid_single_begin: - ldr x0, [sp, #120] + ldr x0, [sp, #248] cbz x0, .Loutput_right_begin .Loutput_mid_loop: @@ -871,7 +875,7 @@ MlasConvNchwcFloatKernelNeonAsm: // Process the right padded output region with bounds checks. .Loutput_right_begin: - ldr x0, [sp, #112] + ldr x0, [sp, #240] cbz x0, .Lepilogue .Loutput_right_loop: @@ -955,7 +959,7 @@ MlasConvNchwcFloatKernelNeonAsm: // KernelFlags == 0 fast path: no accumulation, bias, or activation. // Process the left padded output region with bounds checks. - ldr x0, [sp, #96] + ldr x0, [sp, #224] cbz x0, .Loutput_mid_begin_flags0 .Loutput_left_loop_flags0: @@ -997,12 +1001,12 @@ MlasConvNchwcFloatKernelNeonAsm: // Process the middle output region without bounds checks. .Loutput_mid_begin_flags0: - ldr x0, [sp, #104] + ldr x0, [sp, #232] cbz x0, .Loutput_right_begin_flags0 // Process four outputs at a time to amortize the filter loads. and x16, x0, #3 // Remainder outputs after quads. - str x16, [sp, #120] + str x16, [sp, #248] lsr x0, x0, #2 // Quad output count. cbz x0, .Loutput_mid_remainder_begin_flags0 @@ -1065,7 +1069,7 @@ MlasConvNchwcFloatKernelNeonAsm: b.ne .Loutput_mid_quad_loop_flags0 .Loutput_mid_remainder_begin_flags0: - ldr x0, [sp, #120] + ldr x0, [sp, #248] cbz x0, .Loutput_right_begin_flags0 cmp x0, #3 b.ne .Loutput_mid_remainder_not3_flags0 @@ -1142,7 +1146,7 @@ MlasConvNchwcFloatKernelNeonAsm: .Loutput_mid_pair_begin_flags0: ldr x0, [sp, #120] and x16, x0, #1 - str x16, [sp, #120] + str x16, [sp, #248] lsr x0, x0, #1 cbz x0, .Loutput_mid_single_begin_flags0 @@ -1191,7 +1195,7 @@ MlasConvNchwcFloatKernelNeonAsm: b.ne .Loutput_mid_pair_loop_flags0 .Loutput_mid_single_begin_flags0: - ldr x0, [sp, #120] + ldr x0, [sp, #248] cbz x0, .Loutput_right_begin_flags0 .Loutput_mid_loop_flags0: @@ -1233,7 +1237,7 @@ MlasConvNchwcFloatKernelNeonAsm: // Process the right padded output region with bounds checks. .Loutput_right_begin_flags0: - ldr x0, [sp, #112] + ldr x0, [sp, #240] cbz x0, .Lepilogue .Loutput_right_loop_flags0: @@ -1278,14 +1282,16 @@ MlasConvNchwcFloatKernelNeonAsm: .Lepilogue: // Epilogue and callee-saved register restore. + ldp q14, q15, [sp, #192] + ldp q12, q13, [sp, #160] + ldp q10, q11, [sp, #128] + ldp q8, q9, [sp, #96] ldp x27, x28, [sp, #80] ldp x25, x26, [sp, #64] ldp x23, x24, [sp, #48] ldp x21, x22, [sp, #32] ldp x19, x20, [sp, #16] - ldp x29, x30, [sp], #128 + ldp x29, x30, [sp], #256 ret - .size MlasConvNchwcFloatKernelNeonAsm, .-MlasConvNchwcFloatKernelNeonAsm - #endif // defined(__aarch64__) && !defined(_WIN32) From 0cdec808b6b15f176ac569597d2796ea4804b452 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Mon, 23 Mar 2026 09:35:38 +0000 Subject: [PATCH 3/6] Added benchmark for micro-kernel MlasConvNchwcFloatKernelNeon Signed-off-by: Milos Puzovic --- .../test/mlas/bench/bench_sconv_nchwc.cpp | 243 ++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp diff --git a/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp b/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp new file mode 100644 index 0000000000000..a472143e4840e --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "mlas.h" +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/sconv_nchwc_kernel_neon.h" + +#include +#include +#include +#include + +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32) + +constexpr size_t NchwcBlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +static std::vector ArgNamesForDirectNchwc() { + return {"IC", "OC", "IH", "IW", "KH", "KW", "PT", "PL", "PB", "PR", "S", "D"}; +} + +static size_t ComputeOutputDim(size_t input, size_t kernel, size_t stride, size_t dilation, size_t pad_before, + size_t pad_after) { + const size_t dilated = (kernel - 1) * dilation + 1; + if (input + pad_before + pad_after < dilated) { + throw std::invalid_argument("Invalid shape: input smaller than dilated kernel"); + } + return (input + pad_before + pad_after - dilated) / stride + 1; +} + +static size_t DivideRoundUp(size_t numerator, size_t denominator) { + return (numerator + denominator - 1) / denominator; +} + +static void RunDirectNchwcKernel(const size_t input_channels, + const size_t output_channels, + const size_t input_height, + const size_t input_width, + const size_t kernel_height, + const size_t kernel_width, + const size_t pad_top, + const size_t pad_left, + const size_t pad_bottom, + const size_t pad_right, + const size_t stride, + const size_t dilation, + MLAS_CONV_FLOAT_KERNEL* Kernel, + const float* input, + const float* filter, + const float* bias, + float* output) { + // This routine drives the direct NCHWc micro-kernel exactly like the production + // driver path: one output row and one input-channel block contribution per call. + // The micro-kernel itself does not iterate over full output height or all IC blocks. + const size_t output_height = ComputeOutputDim(input_height, kernel_height, stride, dilation, pad_top, pad_bottom); + const size_t output_width = ComputeOutputDim(input_width, kernel_width, stride, dilation, pad_left, pad_right); + const size_t kernel_size = kernel_height * kernel_width; + + const size_t input_channel_blocks = input_channels / NchwcBlockSize; + const size_t output_channel_blocks = output_channels / NchwcBlockSize; + + const size_t stride_width_bytes = NchwcBlockSize * stride * sizeof(float); + const size_t dilation_width_bytes = NchwcBlockSize * dilation * sizeof(float); + const size_t filter_stride_bytes = NchwcBlockSize * input_channels * kernel_size * sizeof(float); + const size_t output_stride_bytes = NchwcBlockSize * output_height * output_width * sizeof(float); + const size_t input_width_bytes = NchwcBlockSize * input_width * sizeof(float); + const size_t dilated_input_width_bytes = NchwcBlockSize * dilation * input_width * sizeof(float); + const size_t input_stride_bytes = dilated_input_width_bytes - kernel_width * dilation_width_bytes; + const size_t dilated_kernel_width = (kernel_width - 1) * dilation + 1; + + const size_t output_count_left_pad = std::min(output_width, DivideRoundUp(pad_left, stride)); + size_t output_count_right_pad = 0; + while (output_count_right_pad < (output_width - output_count_left_pad)) { + const size_t ox = output_width - output_count_right_pad - 1; + const ptrdiff_t input_x = static_cast(ox * stride) - static_cast(pad_left); + if (input_x + static_cast(dilated_kernel_width) <= static_cast(input_width)) { + break; + } + output_count_right_pad++; + } + const size_t output_count = output_width - output_count_left_pad - output_count_right_pad; + + // Outer IC-block loop is required to accumulate all channel-block contributions: + // first block initializes with bias, remaining blocks accumulate into output. + for (size_t icb = 0; icb < input_channel_blocks; ++icb) { + const unsigned kernel_flags = + (icb == 0) ? MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION : MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT; + + const float* ic_block_input = input + icb * input_height * input_width * NchwcBlockSize; + const float* ic_block_filter = filter + icb * kernel_size * NchwcBlockSize * NchwcBlockSize; + + // Outer OH loop is required because the micro-kernel processes one output row. + for (size_t oh = 0; oh < output_height; ++oh) { + const ptrdiff_t output_origin_h = static_cast(oh * stride) - static_cast(pad_top); + const size_t kh_start = output_origin_h < 0 ? DivideRoundUp(static_cast(-output_origin_h), dilation) : 0; + + size_t kh_end = kernel_height; + const ptrdiff_t input_h_limit = static_cast(input_height); + if (output_origin_h + static_cast((kernel_height - 1) * dilation) >= input_h_limit) { + if (output_origin_h >= input_h_limit) { + kh_end = 0; + } else { + const ptrdiff_t span = input_h_limit - 1 - output_origin_h; + kh_end = static_cast(span / static_cast(dilation)) + 1; + } + } + if (kh_start >= kh_end) { + continue; + } + + const size_t effective_kernel_height = kh_end - kh_start; + const ptrdiff_t input_h_base = output_origin_h + static_cast(kh_start * dilation); + + const float* kernel_input_row = + ic_block_input + static_cast(NchwcBlockSize) * + (input_h_base * static_cast(input_width) - static_cast(pad_left)); + const float* input_base_row = + ic_block_input + static_cast(NchwcBlockSize) * (input_h_base * static_cast(input_width)); + const float* kernel_filter = ic_block_filter + kh_start * kernel_width * NchwcBlockSize * NchwcBlockSize; + float* output_row = output + oh * output_width * NchwcBlockSize; + + Kernel(kernel_input_row, + kernel_filter, + output_row, + stride_width_bytes, + dilation_width_bytes, + output_channel_blocks, + input_stride_bytes, + filter_stride_bytes, + output_stride_bytes, + effective_kernel_height, + kernel_width, + input_base_row, + input_width_bytes, + dilated_input_width_bytes, + output_count_left_pad, + output_count, + output_count_right_pad, + bias, + kernel_flags); + } + } +} + +static void BenchDirectNchwc(benchmark::State& state) { + // It benchmarks the direct NCHWc kernel path only (not full graph/runtime overhead). + // Included: kernel math, row/block driving, and padding edge handling in the driver loop. + // Excluded: threadpool scheduling, graph transforms, model execution, and memory allocator costs. + const size_t input_channels = static_cast(state.range(0)); + const size_t output_channels = static_cast(state.range(1)); + const size_t input_height = static_cast(state.range(2)); + const size_t input_width = static_cast(state.range(3)); + const size_t kernel_height = static_cast(state.range(4)); + const size_t kernel_width = static_cast(state.range(5)); + const size_t pad_top = static_cast(state.range(6)); + const size_t pad_left = static_cast(state.range(7)); + const size_t pad_bottom = static_cast(state.range(8)); + const size_t pad_right = static_cast(state.range(9)); + const size_t stride = static_cast(state.range(10)); + const size_t dilation = static_cast(state.range(11)); + + if (input_channels == 0 || output_channels == 0 || input_height == 0 || input_width == 0 || kernel_height == 0 || + kernel_width == 0 || stride == 0 || dilation == 0) { + throw std::invalid_argument("All benchmark parameters must be > 0"); + } + + if (input_channels % NchwcBlockSize != 0 || output_channels % NchwcBlockSize != 0) { + throw std::invalid_argument("IC and OC must be multiples of MLAS NEON NCHWc block size"); + } + + const size_t output_height = ComputeOutputDim(input_height, kernel_height, stride, dilation, pad_top, pad_bottom); + const size_t output_width = ComputeOutputDim(input_width, kernel_width, stride, dilation, pad_left, pad_right); + const size_t kernel_size = kernel_height * kernel_width; + + const size_t input_channel_blocks = input_channels / NchwcBlockSize; + const size_t output_channel_blocks = output_channels / NchwcBlockSize; + + const size_t input_size = input_channel_blocks * input_height * input_width * NchwcBlockSize; + const size_t filter_size = output_channel_blocks * input_channels * kernel_size * NchwcBlockSize; + const size_t output_size = output_channel_blocks * output_height * output_width * NchwcBlockSize; + + auto input = RandomVectorUniform(std::vector{static_cast(input_size)}, -1.0f, 1.0f); + auto filter = RandomVectorUniform(std::vector{static_cast(filter_size)}, -1.0f, 1.0f); + auto bias = RandomVectorUniform(std::vector{static_cast(output_channels)}, -0.5f, 0.5f); + std::vector output(output_size); + + RunDirectNchwcKernel(input_channels, + output_channels, + input_height, + input_width, + kernel_height, + kernel_width, + pad_top, + pad_left, + pad_bottom, + pad_right, + stride, + dilation, + &MlasConvNchwcFloatKernelNeon, + input.data(), + filter.data(), + bias.data(), + output.data()); + + for (auto _ : state) { + RunDirectNchwcKernel(input_channels, + output_channels, + input_height, + input_width, + kernel_height, + kernel_width, + pad_top, + pad_left, + pad_bottom, + pad_right, + stride, + dilation, + &MlasConvNchwcFloatKernelNeon, + input.data(), + filter.data(), + bias.data(), + output.data()); + } +} + +void SCONV_NCHWC_DIRECT(benchmark::State& state, const char* /*dummy*/) { + BenchDirectNchwc(state); +} + +static void DirectNchwcCases(benchmark::internal::Benchmark* b) { + b->ArgNames(ArgNamesForDirectNchwc()); + + // IC, OC, IH, IW, KH, KW, PT, PL, PB, PR, S, D + b->Args({32, 32, 192, 192, 3, 3, 1, 1, 1, 1, 1, 1}); + b->Args({32, 96, 192, 192, 3, 3, 0, 0, 1, 1, 2, 1}); + b->Args({48, 192, 96, 96, 3, 3, 0, 0, 1, 1, 2, 1}); + b->Args({48, 192, 96, 96, 3, 3, 1, 1, 1, 1, 1, 1}); + b->Args({64, 256, 48, 48, 3, 3, 1, 1, 1, 1, 1, 1}); +} + +BENCHMARK_CAPTURE(SCONV_NCHWC_DIRECT, DirectNchwcCases, "")->Apply(DirectNchwcCases)->UseRealTime(); + +#endif // defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) && !defined(_WIN32) From bd07f6322a2a79a283301d28d1a2f1472f6c3ddf Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Mon, 23 Mar 2026 21:00:14 +0000 Subject: [PATCH 4/6] fix NCHWc asm stack clobber and make local labels portable - Fix SconvNchwcKernelNeon.S KernelFlags==0 remainder path temp spill slot (sp+#120 -> sp+#248) to avoid clobbering callee-saved SIMD spills (q8-q15) - Replace \@-based local labels with numeric local labels (90f/91b style) for portable assembly parsing (including macOS toolchains) Signed-off-by: Milos Puzovic --- .../mlas/lib/aarch64/SconvNchwcKernelNeon.S | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S index 18fdab92b4db9..d301c61ffc257 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S @@ -111,18 +111,18 @@ Abstract: mov x9, x13 // Row start pointer mov x10, x27 // KernelHeight counter - cbz x10, .Lconv_done_pad_\@ - cbz x28, .Lconv_done_pad_\@ + cbz x10, 90f + cbz x28, 90f - .Lkh_loop_pad_\@: + 91: mov x12, x8 // Input pointer for width mov x16, x28 // KernelWidth counter - .Lkw_loop_pad_\@: + 92: // Branch if the input pointer lies outside [row_start, row_start + row_width). sub x11, x12, x9 cmp x11, x14 - b.hs .Lkw_skip_pad_\@ + b.hs 93f ldp q4, q5, [x12, #0] ldp q6, q7, [x12, #32] @@ -144,18 +144,18 @@ Abstract: PROCESS_LANE 896, 928, v7, 2 PROCESS_LANE 960, 992, v7, 3 - .Lkw_skip_pad_\@: + 93: add x7, x7, #1024 add x12, x12, x23 subs x16, x16, #1 - b.ne .Lkw_loop_pad_\@ + b.ne 92b add x9, x9, x15 add x8, x8, x15 subs x10, x10, #1 - b.ne .Lkh_loop_pad_\@ + b.ne 91b - .Lconv_done_pad_\@: + 90: .endm // @@ -168,14 +168,14 @@ Abstract: mov x8, x1 // Input row pointer mov x10, x27 // KernelHeight counter - cbz x10, .Lconv_done_mid_\@ - cbz x28, .Lconv_done_mid_\@ + cbz x10, 100f + cbz x28, 100f - .Lkh_loop_mid_\@: + 101: mov x12, x8 // Input pointer for width mov x16, x28 // KernelWidth counter - .Lkw_loop_mid_\@: + 102: ldp q4, q5, [x12, #0] ldp q6, q7, [x12, #32] @@ -199,13 +199,13 @@ Abstract: add x7, x7, #1024 add x12, x12, x23 subs x16, x16, #1 - b.ne .Lkw_loop_mid_\@ + b.ne 102b add x8, x8, x15 subs x10, x10, #1 - b.ne .Lkh_loop_mid_\@ + b.ne 101b - .Lconv_done_mid_\@: + 100: .endm // @@ -218,14 +218,14 @@ Abstract: mov x8, x1 // Input row pointer (output 0) mov x9, x27 // KernelHeight counter - cbz x9, .Lconv_done_mid3_\@ - cbz x28, .Lconv_done_mid3_\@ + cbz x9, 110f + cbz x28, 110f - .Lkh_loop_mid3_\@: + 111: mov x12, x8 // Input pointer for width (output 0) mov x16, x28 // KernelWidth counter - .Lkw_loop_mid3_\@: + 112: add x11, x12, x22 // Output 1 input pointer add x10, x12, x22, lsl #1 // Output 2 input pointer @@ -256,13 +256,13 @@ Abstract: add x7, x7, #1024 add x12, x12, x23 subs x16, x16, #1 - b.ne .Lkw_loop_mid3_\@ + b.ne 112b add x8, x8, x15 subs x9, x9, #1 - b.ne .Lkh_loop_mid3_\@ + b.ne 111b - .Lconv_done_mid3_\@: + 110: .endm // @@ -276,15 +276,15 @@ Abstract: add x9, x1, x22 // Input row pointer (output 1) mov x10, x27 // KernelHeight counter - cbz x10, .Lconv_done_mid2_\@ - cbz x28, .Lconv_done_mid2_\@ + cbz x10, 120f + cbz x28, 120f - .Lkh_loop_mid2_\@: + 121: mov x12, x8 // Input pointer for width (output 0) mov x11, x9 // Input pointer for width (output 1) mov x16, x28 // KernelWidth counter - .Lkw_loop_mid2_\@: + 122: ldp q4, q5, [x12, #0] ldp q6, q7, [x12, #32] ldp q24, q25, [x11, #0] @@ -311,14 +311,14 @@ Abstract: add x12, x12, x23 add x11, x11, x23 subs x16, x16, #1 - b.ne .Lkw_loop_mid2_\@ + b.ne 122b add x8, x8, x15 add x9, x9, x15 subs x10, x10, #1 - b.ne .Lkh_loop_mid2_\@ + b.ne 121b - .Lconv_done_mid2_\@: + 120: .endm // @@ -332,14 +332,14 @@ Abstract: mov x8, x1 // Input row pointer (output 0) mov x6, x27 // KernelHeight counter (x6 is scratch in this path) - cbz x6, .Lconv_done_mid4_\@ - cbz x28, .Lconv_done_mid4_\@ + cbz x6, 130f + cbz x28, 130f - .Lkh_loop_mid4_\@: + 131: mov x12, x8 // Input pointer for width (output 0) mov x16, x28 // KernelWidth counter - .Lkw_loop_mid4_\@: + 132: add x9, x12, x22 // Output 1 input pointer add x10, x9, x22 // Output 2 input pointer add x11, x10, x22 // Output 3 input pointer @@ -377,13 +377,13 @@ Abstract: add x7, x7, #1024 add x12, x12, x23 subs x16, x16, #1 - b.ne .Lkw_loop_mid4_\@ + b.ne 132b add x8, x8, x15 subs x6, x6, #1 - b.ne .Lkh_loop_mid4_\@ + b.ne 131b - .Lconv_done_mid4_\@: + 130: .endm // @@ -1074,7 +1074,7 @@ Abstract: cmp x0, #3 b.ne .Loutput_mid_remainder_not3_flags0 // Exactly three outputs remain. - str xzr, [sp, #120] + str xzr, [sp, #248] mov x0, #1 b .Loutput_mid_triad_loop_flags0 @@ -1082,7 +1082,7 @@ Abstract: cmp x0, #2 b.ne .Loutput_mid_remainder_single_flags0 // Exactly two outputs remain. - str xzr, [sp, #120] + str xzr, [sp, #248] mov x0, #1 b .Loutput_mid_pair_loop_flags0 @@ -1144,7 +1144,7 @@ Abstract: b.ne .Loutput_mid_triad_loop_flags0 .Loutput_mid_pair_begin_flags0: - ldr x0, [sp, #120] + ldr x0, [sp, #248] and x16, x0, #1 str x16, [sp, #248] lsr x0, x0, #1 From b928c7d9698281eedfa7c246f6f9bf9ac42c4121 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Tue, 24 Mar 2026 00:04:31 +0000 Subject: [PATCH 5/6] Avoid using x18/w18 as general purpose register to fix MacOS ABI breakage Signed-off-by: Milos Puzovic --- .../mlas/lib/aarch64/SconvNchwcKernelNeon.S | 54 ++++++++++--------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S index d301c61ffc257..0b72caaa1c850 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S @@ -418,7 +418,7 @@ Abstract: // Prologue and callee-saved register spill. // Save callee-saved SIMD registers v8-v15 per AArch64 ABI. - stp x29, x30, [sp, #-256]! + stp x29, x30, [sp, #-272]! mov x29, sp stp x19, x20, [sp, #16] stp x21, x22, [sp, #32] @@ -448,26 +448,26 @@ Abstract: ldr x15, [x9, #40] // DilatedInputWidth (bytes) ldr x16, [x9, #48] // OutputCountLeftPad ldr x17, [x9, #56] // OutputCount - ldr x18, [x9, #64] // OutputCountRightPad ldr x6, [x9, #72] // Bias ldr w8, [x9, #80] // KernelFlags // Early exit when nothing to compute. + mov x26, x10 // OutputStride (bytes) + ldr x10, [x9, #64] // OutputCountRightPad add x0, x16, x17 - add x0, x0, x18 // x0 = TotalOutputCount + add x0, x0, x10 // x0 = TotalOutputCount cbz x0, .Lepilogue cbz x24, .Lepilogue - // Spill the output counts so that x16/x17/x18 can be used as scratch. + // Spill the output counts so that x16/x17 can be used as scratch. str x16, [sp, #224] str x17, [sp, #232] - str x18, [sp, #240] + str x10, [sp, #240] + str w8, [sp, #256] - mov x26, x10 // OutputStride (bytes) mov x27, x11 // KernelHeight mov x28, x12 // KernelWidth mov x17, x6 // Bias - mov w18, w8 // KernelFlags // Set up a zero vector for ReLU. movi v31.4s, #0 @@ -479,7 +479,8 @@ Abstract: // Fast path when no post-processing flags are enabled. This removes // repeated flag checks and branches from the steady-state loops. - tst w18, #7 + ldr w9, [sp, #256] + tst w9, #7 b.eq .Lkernel_flags0 // Process the left padded output region with bounds checks. @@ -509,7 +510,8 @@ Abstract: add x12, x5, x2 // Conditionally accumulate the existing output. - tst w18, #1 + ldr w9, [sp, #256] + tst w9, #1 b.eq .Lskip_accumulate ldp q16, q17, [x12, #0] ldp q18, q19, [x12, #32] @@ -521,7 +523,7 @@ Abstract: .Lskip_accumulate: // Conditionally add bias. - tst w18, #2 + tst w9, #2 b.eq .Lskip_bias ldp q16, q17, [x6, #0] ldp q18, q19, [x6, #32] @@ -533,7 +535,7 @@ Abstract: .Lskip_bias: // Conditionally apply ReLU. - tst w18, #4 + tst w9, #4 b.eq .Lskip_relu fmax v0.4s, v0.4s, v31.4s fmax v1.4s, v1.4s, v31.4s @@ -612,7 +614,8 @@ Abstract: add x10, x12, #128 // Conditionally accumulate the existing output. - tst w18, #1 + ldr w9, [sp, #256] + tst w9, #1 b.eq .Lskip_accumulate_triad ldp q16, q17, [x12, #0] ldp q18, q19, [x12, #32] @@ -636,7 +639,7 @@ Abstract: .Lskip_accumulate_triad: // Conditionally add bias. - tst w18, #2 + tst w9, #2 b.eq .Lskip_bias_triad ldp q16, q17, [x6, #0] ldp q18, q19, [x6, #32] @@ -656,7 +659,7 @@ Abstract: .Lskip_bias_triad: // Conditionally apply ReLU. - tst w18, #4 + tst w9, #4 b.eq .Lskip_relu_triad fmax v0.4s, v0.4s, v31.4s fmax v1.4s, v1.4s, v31.4s @@ -730,7 +733,8 @@ Abstract: add x11, x12, #64 // Conditionally accumulate the existing output. - tst w18, #1 + ldr w9, [sp, #256] + tst w9, #1 b.eq .Lskip_accumulate_pair ldp q16, q17, [x12, #0] ldp q18, q19, [x12, #32] @@ -748,7 +752,7 @@ Abstract: .Lskip_accumulate_pair: // Conditionally add bias. - tst w18, #2 + tst w9, #2 b.eq .Lskip_bias_pair ldp q16, q17, [x6, #0] ldp q18, q19, [x6, #32] @@ -764,7 +768,7 @@ Abstract: .Lskip_bias_pair: // Conditionally apply ReLU. - tst w18, #4 + tst w9, #4 b.eq .Lskip_relu_pair fmax v0.4s, v0.4s, v31.4s fmax v1.4s, v1.4s, v31.4s @@ -823,7 +827,8 @@ Abstract: add x12, x5, x2 // Conditionally accumulate the existing output. - tst w18, #1 + ldr w9, [sp, #256] + tst w9, #1 b.eq .Lskip_accumulate_mid ldp q16, q17, [x12, #0] ldp q18, q19, [x12, #32] @@ -835,7 +840,7 @@ Abstract: .Lskip_accumulate_mid: // Conditionally add bias. - tst w18, #2 + tst w9, #2 b.eq .Lskip_bias_mid ldp q16, q17, [x6, #0] ldp q18, q19, [x6, #32] @@ -847,7 +852,7 @@ Abstract: .Lskip_bias_mid: // Conditionally apply ReLU. - tst w18, #4 + tst w9, #4 b.eq .Lskip_relu_mid fmax v0.4s, v0.4s, v31.4s fmax v1.4s, v1.4s, v31.4s @@ -901,7 +906,8 @@ Abstract: add x12, x5, x2 // Conditionally accumulate the existing output. - tst w18, #1 + ldr w9, [sp, #256] + tst w9, #1 b.eq .Lskip_accumulate_right ldp q16, q17, [x12, #0] ldp q18, q19, [x12, #32] @@ -913,7 +919,7 @@ Abstract: .Lskip_accumulate_right: // Conditionally add bias. - tst w18, #2 + tst w9, #2 b.eq .Lskip_bias_right ldp q16, q17, [x6, #0] ldp q18, q19, [x6, #32] @@ -925,7 +931,7 @@ Abstract: .Lskip_bias_right: // Conditionally apply ReLU. - tst w18, #4 + tst w9, #4 b.eq .Lskip_relu_right fmax v0.4s, v0.4s, v31.4s fmax v1.4s, v1.4s, v31.4s @@ -1291,7 +1297,7 @@ Abstract: ldp x23, x24, [sp, #48] ldp x21, x22, [sp, #32] ldp x19, x20, [sp, #16] - ldp x29, x30, [sp], #256 + ldp x29, x30, [sp], #272 ret #endif // defined(__aarch64__) && !defined(_WIN32) From 89809bbe6e7cd948ad72c4e621f0c94cdf444005 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Tue, 24 Mar 2026 19:10:53 +0000 Subject: [PATCH 6/6] Address comments from Copilot --- .../test/mlas/bench/bench_sconv_nchwc.cpp | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp b/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp index a472143e4840e..5e527a48a683d 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp @@ -82,10 +82,18 @@ static void RunDirectNchwcKernel(const size_t input_channels, // Outer IC-block loop is required to accumulate all channel-block contributions: // first block initializes with bias, remaining blocks accumulate into output. + // the production driver sets ACCUMULATE_OUTPUT for all but the first IC block, + // and applies BIAS_ADDITION only on the final IC block. for (size_t icb = 0; icb < input_channel_blocks; ++icb) { - const unsigned kernel_flags = - (icb == 0) ? MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION : MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT; - + const bool is_first_ic_block = (icb == 0); + const bool is_last_ic_block = (icb + 1 == input_channel_blocks); + unsigned kernel_flags = 0; + if (!is_first_ic_block) { + kernel_flags |= MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT; + } + if (is_last_ic_block && bias != nullptr) { + kernel_flags |= MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION; + } const float* ic_block_input = input + icb * input_height * input_width * NchwcBlockSize; const float* ic_block_filter = filter + icb * kernel_size * NchwcBlockSize * NchwcBlockSize; @@ -111,9 +119,13 @@ static void RunDirectNchwcKernel(const size_t input_channels, const size_t effective_kernel_height = kh_end - kh_start; const ptrdiff_t input_h_base = output_origin_h + static_cast(kh_start * dilation); + ptrdiff_t kernel_row_index = + input_h_base * static_cast(input_width) - static_cast(pad_left); + if (kernel_row_index < 0) { + kernel_row_index = 0; + } const float* kernel_input_row = - ic_block_input + static_cast(NchwcBlockSize) * - (input_h_base * static_cast(input_width) - static_cast(pad_left)); + ic_block_input + static_cast(NchwcBlockSize) * kernel_row_index; const float* input_base_row = ic_block_input + static_cast(NchwcBlockSize) * (input_h_base * static_cast(input_width)); const float* kernel_filter = ic_block_filter + kh_start * kernel_width * NchwcBlockSize * NchwcBlockSize;