diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 29ff1780f12e2..73ab39048de4f 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -334,6 +334,7 @@ function (setup_arm_neon_nchwc) # separate assembly file allows tighter control over register allocation # and avoids the overhead of C++/intrinsics based code generation. ${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S ) diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S new file mode 100644 index 0000000000000..1c2a01a0db0b9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvNchwcKernelNeon.S @@ -0,0 +1,962 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + SconvNchwcKernelNeon.S + +Abstract: + Hand written AArch64 vectorised kernel used by the convolution path + (NCHWc activations and NCHWc weights). The kernel computes one output + row and processes up to four 16-wide filter blocks (FilterCount <= 4). +--*/ + +#include "asmmacro.h" + + // Stack layout for the callee-saved register frame plus the spilled + // tail of the convolution parameter list. The first eight integer + // arguments arrive in x0-x7 and everything after that is read back + // from fixed offsets relative to sp. 
+ + .equ .LFrame_SavedRegs, (5*16 + 4*32 + 16 + 32) + .equ .LFrame_x19_x20, 0 + .equ .LFrame_x21_x22, 16 + .equ .LFrame_x23_x24, 32 + .equ .LFrame_x25_x26, 48 + .equ .LFrame_x27_x28, 64 + .equ .LFrame_q8_q9, 80 + .equ .LFrame_q10_q11, 112 + .equ .LFrame_q12_q13, 144 + .equ .LFrame_q14_q15, 176 + .equ .LFrame_InputBaseSaved, 208 + .equ .LFrame_FilterBaseSaved, 216 + .equ .LFrame_OutputBaseSaved, 224 + .equ .LFrame_LrSaved, 232 + + .equ .KO_OutputStride, (0 + .LFrame_SavedRegs) + .equ .KO_KernelHeight, (8 + .LFrame_SavedRegs) + .equ .KO_KernelWidth, (16 + .LFrame_SavedRegs) + .equ .KO_InputBase, (24 + .LFrame_SavedRegs) + .equ .KO_InputWidth, (32 + .LFrame_SavedRegs) + .equ .KO_DilatedInputWidth, (40 + .LFrame_SavedRegs) + .equ .KO_OutputCountLeftPad, (48 + .LFrame_SavedRegs) + .equ .KO_OutputCount, (56 + .LFrame_SavedRegs) + .equ .KO_OutputCountRightPad, (64 + .LFrame_SavedRegs) + .equ .KO_Bias, (72 + .LFrame_SavedRegs) + .equ .KO_Flags, (80 + .LFrame_SavedRegs) + + .text + + // --------------------------------------------------------------------- + // Helper macros that build the computation, postprocess, and store + // phases used by the NCHWc kernel body. + // --------------------------------------------------------------------- + + // Zero the accumulator vectors for a 1-4 filter-block microkernel. 
+ .macro CLEAR_ACCUM N + eor v0.16b,v0.16b,v0.16b + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + .if \N >= 2 + eor v4.16b,v4.16b,v4.16b + eor v5.16b,v5.16b,v5.16b + eor v6.16b,v6.16b,v6.16b + eor v7.16b,v7.16b,v7.16b + .endif + .if \N >= 3 + eor v8.16b,v8.16b,v8.16b + eor v9.16b,v9.16b,v9.16b + eor v10.16b,v10.16b,v10.16b + eor v11.16b,v11.16b,v11.16b + .endif + .if \N >= 4 + eor v12.16b,v12.16b,v12.16b + eor v13.16b,v13.16b,v13.16b + eor v14.16b,v14.16b,v14.16b + eor v15.16b,v15.16b,v15.16b + .endif + .endm + + // Finalize one output position for N filter blocks by optionally + // adding the prior output, adding bias, applying ReLU, and storing. + // x27 points at the first output block, x8 is the output stride, + // x17 is the bias pointer, v31 is zero, and w19 holds KernelFlags. + .macro POSTPROCESS_STORE_1OUTPUT_NFILTERS N + tbz w19,#0,1f + ldr q16,[x27] + ldr q17,[x27,#16] + ldr q18,[x27,#32] + ldr q19,[x27,#48] + fadd v0.4s,v0.4s,v16.4s + fadd v1.4s,v1.4s,v17.4s + fadd v2.4s,v2.4s,v18.4s + fadd v3.4s,v3.4s,v19.4s + .if \N >= 2 + add x24,x27,x8 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v4.4s,v4.4s,v16.4s + fadd v5.4s,v5.4s,v17.4s + fadd v6.4s,v6.4s,v18.4s + fadd v7.4s,v7.4s,v19.4s + .endif + .if \N >= 3 + add x24,x27,x8,lsl #1 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v8.4s,v8.4s,v16.4s + fadd v9.4s,v9.4s,v17.4s + fadd v10.4s,v10.4s,v18.4s + fadd v11.4s,v11.4s,v19.4s + .endif + .if \N >= 4 + add x24,x27,x8 + add x24,x24,x8,lsl #1 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v12.4s,v12.4s,v16.4s + fadd v13.4s,v13.4s,v17.4s + fadd v14.4s,v14.4s,v18.4s + fadd v15.4s,v15.4s,v19.4s + .endif + 1: + tbz w19,#1,2f + ldr q16,[x17] + ldr q17,[x17,#16] + ldr q18,[x17,#32] + ldr q19,[x17,#48] + fadd v0.4s,v0.4s,v16.4s + fadd v1.4s,v1.4s,v17.4s + fadd v2.4s,v2.4s,v18.4s + fadd v3.4s,v3.4s,v19.4s + .if \N >= 2 + add 
x24,x17,#64 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v4.4s,v4.4s,v16.4s + fadd v5.4s,v5.4s,v17.4s + fadd v6.4s,v6.4s,v18.4s + fadd v7.4s,v7.4s,v19.4s + .endif + .if \N >= 3 + add x24,x17,#128 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v8.4s,v8.4s,v16.4s + fadd v9.4s,v9.4s,v17.4s + fadd v10.4s,v10.4s,v18.4s + fadd v11.4s,v11.4s,v19.4s + .endif + .if \N >= 4 + add x24,x17,#192 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v12.4s,v12.4s,v16.4s + fadd v13.4s,v13.4s,v17.4s + fadd v14.4s,v14.4s,v18.4s + fadd v15.4s,v15.4s,v19.4s + .endif + 2: + tbz w19,#2,3f + fmax v0.4s,v0.4s,v31.4s + fmax v1.4s,v1.4s,v31.4s + fmax v2.4s,v2.4s,v31.4s + fmax v3.4s,v3.4s,v31.4s + .if \N >= 2 + fmax v4.4s,v4.4s,v31.4s + fmax v5.4s,v5.4s,v31.4s + fmax v6.4s,v6.4s,v31.4s + fmax v7.4s,v7.4s,v31.4s + .endif + .if \N >= 3 + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s + .endif + .if \N >= 4 + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s + .endif + 3: + str q0,[x27] + str q1,[x27,#16] + str q2,[x27,#32] + str q3,[x27,#48] + .if \N >= 2 + add x24,x27,x8 + str q4,[x24] + str q5,[x24,#16] + str q6,[x24,#32] + str q7,[x24,#48] + .endif + .if \N >= 3 + add x24,x27,x8,lsl #1 + str q8,[x24] + str q9,[x24,#16] + str q10,[x24,#32] + str q11,[x24,#48] + .endif + .if \N >= 4 + add x24,x27,x8 + add x24,x24,x8,lsl #1 + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] + .endif + .endm + + // Zero the wider accumulator set used when computing two outputs at + // once for the same filter block. 
+ .macro CLEAR_ACCUM2 N + eor v0.16b,v0.16b,v0.16b + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + eor v4.16b,v4.16b,v4.16b + eor v5.16b,v5.16b,v5.16b + eor v6.16b,v6.16b,v6.16b + eor v7.16b,v7.16b,v7.16b + .if \N >= 2 + eor v8.16b,v8.16b,v8.16b + eor v9.16b,v9.16b,v9.16b + eor v10.16b,v10.16b,v10.16b + eor v11.16b,v11.16b,v11.16b + eor v12.16b,v12.16b,v12.16b + eor v13.16b,v13.16b,v13.16b + eor v14.16b,v14.16b,v14.16b + eor v15.16b,v15.16b,v15.16b + .endif + .if \N >= 3 + eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor v19.16b,v19.16b,v19.16b + eor v20.16b,v20.16b,v20.16b + eor v21.16b,v21.16b,v21.16b + eor v22.16b,v22.16b,v22.16b + eor v23.16b,v23.16b,v23.16b + .endif + .endm + + // Finalize and store two output positions for a single filter block. + .macro POSTPROCESS_STORE_2OUTPUTS_FC1_NOPAD + add x24,x27,#64 + tbz w19,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v0.4s,v0.4s,v26.4s + fadd v1.4s,v1.4s,v27.4s + fadd v2.4s,v2.4s,v28.4s + fadd v3.4s,v3.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v4.4s,v4.4s,v26.4s + fadd v5.4s,v5.4s,v27.4s + fadd v6.4s,v6.4s,v28.4s + fadd v7.4s,v7.4s,v29.4s + 1: + tbz w19,#1,2f + ldr q26,[x17] + ldr q27,[x17,#16] + ldr q28,[x17,#32] + ldr q29,[x17,#48] + fadd v0.4s,v0.4s,v26.4s + fadd v1.4s,v1.4s,v27.4s + fadd v2.4s,v2.4s,v28.4s + fadd v3.4s,v3.4s,v29.4s + fadd v4.4s,v4.4s,v26.4s + fadd v5.4s,v5.4s,v27.4s + fadd v6.4s,v6.4s,v28.4s + fadd v7.4s,v7.4s,v29.4s + 2: + tbz w19,#2,3f + fmax v0.4s,v0.4s,v31.4s + fmax v1.4s,v1.4s,v31.4s + fmax v2.4s,v2.4s,v31.4s + fmax v3.4s,v3.4s,v31.4s + fmax v4.4s,v4.4s,v31.4s + fmax v5.4s,v5.4s,v31.4s + fmax v6.4s,v6.4s,v31.4s + fmax v7.4s,v7.4s,v31.4s + 3: + str q0,[x27] + str q1,[x27,#16] + str q2,[x27,#32] + str q3,[x27,#48] + str q4,[x24] + str q5,[x24,#16] + str q6,[x24,#32] + str q7,[x24,#48] + .endm + + // Finalize and store 
two output positions for two filter blocks. + .macro POSTPROCESS_STORE_2OUTPUTS_FC2_NOPAD + POSTPROCESS_STORE_2OUTPUTS_FC1_NOPAD + add x24,x27,x8 + add x24,x24,#64 + add x27,x27,x8 + tbz w19,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s + 1: + tbz w19,#1,2f + add x23,x17,#64 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s + 2: + tbz w19,#2,3f + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s + 3: + str q8,[x27] + str q9,[x27,#16] + str q10,[x27,#32] + str q11,[x27,#48] + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] + sub x27,x27,x8 + .endm + + // Finalize and store two output positions for three filter blocks. 
+ .macro POSTPROCESS_STORE_2OUTPUTS_FC3_NOPAD + POSTPROCESS_STORE_2OUTPUTS_FC1_NOPAD + add x24,x27,x8 + add x24,x24,#64 + add x27,x27,x8 + tbz w19,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s + 1: + tbz w19,#1,2f + add x23,x17,#64 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s + 2: + tbz w19,#2,3f + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s + 3: + str q8,[x27] + str q9,[x27,#16] + str q10,[x27,#32] + str q11,[x27,#48] + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] + add x27,x27,x8 + add x24,x27,#64 + tbz w19,#0,4f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v16.4s,v16.4s,v26.4s + fadd v17.4s,v17.4s,v27.4s + fadd v18.4s,v18.4s,v28.4s + fadd v19.4s,v19.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v20.4s,v20.4s,v26.4s + fadd v21.4s,v21.4s,v27.4s + fadd v22.4s,v22.4s,v28.4s + fadd v23.4s,v23.4s,v29.4s + 4: + tbz w19,#1,5f + add x23,x17,#128 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v16.4s,v16.4s,v26.4s + fadd v17.4s,v17.4s,v27.4s + fadd v18.4s,v18.4s,v28.4s + fadd v19.4s,v19.4s,v29.4s + fadd v20.4s,v20.4s,v26.4s + fadd v21.4s,v21.4s,v27.4s + fadd v22.4s,v22.4s,v28.4s + 
fadd v23.4s,v23.4s,v29.4s + 5: + tbz w19,#2,6f + fmax v16.4s,v16.4s,v31.4s + fmax v17.4s,v17.4s,v31.4s + fmax v18.4s,v18.4s,v31.4s + fmax v19.4s,v19.4s,v31.4s + fmax v20.4s,v20.4s,v31.4s + fmax v21.4s,v21.4s,v31.4s + fmax v22.4s,v22.4s,v31.4s + fmax v23.4s,v23.4s,v31.4s + 6: + str q16,[x27] + str q17,[x27,#16] + str q18,[x27,#32] + str q19,[x27,#48] + str q20,[x24] + str q21,[x24,#16] + str q22,[x24,#32] + str q23,[x24,#48] + sub x27,x27,x8 + sub x27,x27,x8 + .endm + + // Consume one 16x16 weight row and accumulate it into one output + // position using a selected lane from the loaded activation row. + .macro FMLA_ROW AReg, Lane + ld1 {v24.4s,v25.4s,v26.4s,v27.4s},[x28],#64 + fmla v0.4s,v24.4s,\AReg\().s[\Lane] + fmla v1.4s,v25.4s,\AReg\().s[\Lane] + fmla v2.4s,v26.4s,\AReg\().s[\Lane] + fmla v3.4s,v27.4s,\AReg\().s[\Lane] + .endm + + // Consume one 16x16 weight row and accumulate it into two output + // positions, reusing the same filter load for both outputs. + .macro FMLA_ROW2 AReg0, Lane0, AReg1, Lane1 + ld1 {v24.4s,v25.4s,v26.4s,v27.4s},[x28],#64 + fmla v0.4s,v24.4s,\AReg0\().s[\Lane0] + fmla v1.4s,v25.4s,\AReg0\().s[\Lane0] + fmla v2.4s,v26.4s,\AReg0\().s[\Lane0] + fmla v3.4s,v27.4s,\AReg0\().s[\Lane0] + fmla v4.4s,v24.4s,\AReg1\().s[\Lane1] + fmla v5.4s,v25.4s,\AReg1\().s[\Lane1] + fmla v6.4s,v26.4s,\AReg1\().s[\Lane1] + fmla v7.4s,v27.4s,\AReg1\().s[\Lane1] + .endm + + // Consume one 16x16 weight row for two filter blocks and accumulate it + // into one output position. 
+ .macro FMLA_ROW_FC2 AReg, Lane + ld1 {v24.4s,v25.4s,v26.4s,v27.4s},[x28],#64 + fmla v0.4s,v24.4s,\AReg\().s[\Lane] + fmla v1.4s,v25.4s,\AReg\().s[\Lane] + fmla v2.4s,v26.4s,\AReg\().s[\Lane] + fmla v3.4s,v27.4s,\AReg\().s[\Lane] + ld1 {v24.4s,v25.4s,v26.4s,v27.4s},[x13],#64 + fmla v4.4s,v24.4s,\AReg\().s[\Lane] + fmla v5.4s,v25.4s,\AReg\().s[\Lane] + fmla v6.4s,v26.4s,\AReg\().s[\Lane] + fmla v7.4s,v27.4s,\AReg\().s[\Lane] + .endm + + // Consume one 16x16 weight row for two filter blocks and accumulate it + // into two output positions. + .macro FMLA_ROW2_FC2 AReg0, Lane0, AReg1, Lane1 + ld1 {v24.4s,v25.4s,v26.4s,v27.4s},[x28],#64 + fmla v0.4s,v24.4s,\AReg0\().s[\Lane0] + fmla v1.4s,v25.4s,\AReg0\().s[\Lane0] + fmla v2.4s,v26.4s,\AReg0\().s[\Lane0] + fmla v3.4s,v27.4s,\AReg0\().s[\Lane0] + fmla v4.4s,v24.4s,\AReg1\().s[\Lane1] + fmla v5.4s,v25.4s,\AReg1\().s[\Lane1] + fmla v6.4s,v26.4s,\AReg1\().s[\Lane1] + fmla v7.4s,v27.4s,\AReg1\().s[\Lane1] + ld1 {v24.4s,v25.4s,v26.4s,v27.4s},[x13],#64 + fmla v8.4s,v24.4s,\AReg0\().s[\Lane0] + fmla v9.4s,v25.4s,\AReg0\().s[\Lane0] + fmla v10.4s,v26.4s,\AReg0\().s[\Lane0] + fmla v11.4s,v27.4s,\AReg0\().s[\Lane0] + fmla v12.4s,v24.4s,\AReg1\().s[\Lane1] + fmla v13.4s,v25.4s,\AReg1\().s[\Lane1] + fmla v14.4s,v26.4s,\AReg1\().s[\Lane1] + fmla v15.4s,v27.4s,\AReg1\().s[\Lane1] + .endm + + // Compute one output position for the interior no-padding case. The + // outer loop walks kernel rows, the inner loop walks kernel columns, + // and each iteration multiplies one 16-lane input block by one 16x16 + // filter tile before handing off to the single-output postprocess path. 
+ .macro CONV_NCHWC_FC1_1OUTPUT_NOPAD + CLEAR_ACCUM 1 + mov x28,x1 + mov x21,x0 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,3f + ldr x10,[sp,#.KO_KernelWidth] + 1: mov x25,x21 + mov x11,x10 + 2: + ldp q20,q21,[x25] + ldp q22,q23,[x25,#32] + FMLA_ROW v20,0 + FMLA_ROW v20,1 + FMLA_ROW v20,2 + FMLA_ROW v20,3 + FMLA_ROW v21,0 + FMLA_ROW v21,1 + FMLA_ROW v21,2 + FMLA_ROW v21,3 + FMLA_ROW v22,0 + FMLA_ROW v22,1 + FMLA_ROW v22,2 + FMLA_ROW v22,3 + FMLA_ROW v23,0 + FMLA_ROW v23,1 + FMLA_ROW v23,2 + FMLA_ROW v23,3 + add x25,x25,x4 + subs x11,x11,#1 + b.ne 2b + add x21,x25,x6 + subs x9,x9,#1 + b.ne 1b + 3: POSTPROCESS_STORE_1OUTPUT_NFILTERS 1 + .endm + + // Compute one output position for two filter blocks in the interior + // no-padding case. + .macro CONV_NCHWC_FC2_1OUTPUT_NOPAD + CLEAR_ACCUM 2 + mov x28,x1 + add x13,x1,x7 + mov x21,x0 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,3f + ldr x10,[sp,#.KO_KernelWidth] + 1: mov x25,x21 + mov x11,x10 + 2: + ldp q20,q21,[x25] + ldp q22,q23,[x25,#32] + FMLA_ROW_FC2 v20,0 + FMLA_ROW_FC2 v20,1 + FMLA_ROW_FC2 v20,2 + FMLA_ROW_FC2 v20,3 + FMLA_ROW_FC2 v21,0 + FMLA_ROW_FC2 v21,1 + FMLA_ROW_FC2 v21,2 + FMLA_ROW_FC2 v21,3 + FMLA_ROW_FC2 v22,0 + FMLA_ROW_FC2 v22,1 + FMLA_ROW_FC2 v22,2 + FMLA_ROW_FC2 v22,3 + FMLA_ROW_FC2 v23,0 + FMLA_ROW_FC2 v23,1 + FMLA_ROW_FC2 v23,2 + FMLA_ROW_FC2 v23,3 + add x25,x25,x4 + subs x11,x11,#1 + b.ne 2b + add x21,x25,x6 + subs x9,x9,#1 + b.ne 1b + 3: POSTPROCESS_STORE_1OUTPUT_NFILTERS 2 + .endm + + // Compute two adjacent output positions together for the interior + // no-padding case. The loops are the same as the single-output form, + // but each filter load is shared across two activation positions to + // improve throughput before the paired postprocess/store step. 
+ .macro CONV_NCHWC_FC1_2OUTPUTS_NOPAD + CLEAR_ACCUM2 1 + mov x28,x1 + mov x21,x0 + add x22,x0,x3 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,3f + ldr x10,[sp,#.KO_KernelWidth] + 1: mov x25,x21 + mov x26,x22 + mov x11,x10 + 2: + ldp q16,q17,[x25] + ldp q18,q19,[x25,#32] + ldp q20,q21,[x26] + ldp q22,q23,[x26,#32] + FMLA_ROW2 v16,0,v20,0 + FMLA_ROW2 v16,1,v20,1 + FMLA_ROW2 v16,2,v20,2 + FMLA_ROW2 v16,3,v20,3 + FMLA_ROW2 v17,0,v21,0 + FMLA_ROW2 v17,1,v21,1 + FMLA_ROW2 v17,2,v21,2 + FMLA_ROW2 v17,3,v21,3 + FMLA_ROW2 v18,0,v22,0 + FMLA_ROW2 v18,1,v22,1 + FMLA_ROW2 v18,2,v22,2 + FMLA_ROW2 v18,3,v22,3 + FMLA_ROW2 v19,0,v23,0 + FMLA_ROW2 v19,1,v23,1 + FMLA_ROW2 v19,2,v23,2 + FMLA_ROW2 v19,3,v23,3 + add x25,x25,x4 + add x26,x26,x4 + subs x11,x11,#1 + b.ne 2b + add x21,x25,x6 + add x22,x26,x6 + subs x9,x9,#1 + b.ne 1b + 3: POSTPROCESS_STORE_2OUTPUTS_FC1_NOPAD + .endm + + // Compute two adjacent output positions together for two filter blocks + // in the interior no-padding case. + .macro CONV_NCHWC_FC2_2OUTPUTS_NOPAD + CLEAR_ACCUM2 2 + mov x28,x1 + add x13,x1,x7 + mov x21,x0 + add x22,x0,x3 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,3f + ldr x10,[sp,#.KO_KernelWidth] + 1: mov x25,x21 + mov x26,x22 + mov x11,x10 + 2: + ldp q16,q17,[x25] + ldp q18,q19,[x25,#32] + ldp q20,q21,[x26] + ldp q22,q23,[x26,#32] + FMLA_ROW2_FC2 v16,0,v20,0 + FMLA_ROW2_FC2 v16,1,v20,1 + FMLA_ROW2_FC2 v16,2,v20,2 + FMLA_ROW2_FC2 v16,3,v20,3 + FMLA_ROW2_FC2 v17,0,v21,0 + FMLA_ROW2_FC2 v17,1,v21,1 + FMLA_ROW2_FC2 v17,2,v21,2 + FMLA_ROW2_FC2 v17,3,v21,3 + FMLA_ROW2_FC2 v18,0,v22,0 + FMLA_ROW2_FC2 v18,1,v22,1 + FMLA_ROW2_FC2 v18,2,v22,2 + FMLA_ROW2_FC2 v18,3,v22,3 + FMLA_ROW2_FC2 v19,0,v23,0 + FMLA_ROW2_FC2 v19,1,v23,1 + FMLA_ROW2_FC2 v19,2,v23,2 + FMLA_ROW2_FC2 v19,3,v23,3 + add x25,x25,x4 + add x26,x26,x4 + subs x11,x11,#1 + b.ne 2b + add x21,x25,x6 + add x22,x26,x6 + subs x9,x9,#1 + b.ne 1b + 3: POSTPROCESS_STORE_2OUTPUTS_FC2_NOPAD + .endm + + // Restore the nonvolatile general-purpose and vector registers 
and + // return to the C++ caller. + .macro RESTORE_NONVOLATILE_AND_RETURN + ldp q14,q15,[sp,#.LFrame_q14_q15] + ldp q12,q13,[sp,#.LFrame_q12_q13] + ldp q10,q11,[sp,#.LFrame_q10_q11] + ldp q8,q9,[sp,#.LFrame_q8_q9] + ldp x27,x28,[sp,#.LFrame_x27_x28] + ldp x25,x26,[sp,#.LFrame_x25_x26] + ldp x23,x24,[sp,#.LFrame_x23_x24] + ldp x21,x22,[sp,#.LFrame_x21_x22] + ldr x30,[sp,#.LFrame_LrSaved] + ldp x19,x20,[sp],#.LFrame_SavedRegs + ret + .endm + +// ----------------------------------------------------------------------------- +// void MlasConvNchwcFloatKernelNeonAsm(...) +// Implements the convolution micro-kernel for the NCHWc input format. +// ----------------------------------------------------------------------------- + + FUNCTION_ENTRY MlasConvNchwcFloatKernelNeonAsm + + stp x19,x20,[sp,#-.LFrame_SavedRegs]! + stp x21,x22,[sp,#.LFrame_x21_x22] + stp x23,x24,[sp,#.LFrame_x23_x24] + stp x25,x26,[sp,#.LFrame_x25_x26] + stp x27,x28,[sp,#.LFrame_x27_x28] + stp q8,q9,[sp,#.LFrame_q8_q9] + stp q10,q11,[sp,#.LFrame_q10_q11] + stp q12,q13,[sp,#.LFrame_q12_q13] + stp q14,q15,[sp,#.LFrame_q14_q15] + str x0,[sp,#.LFrame_InputBaseSaved] + str x1,[sp,#.LFrame_FilterBaseSaved] + str x2,[sp,#.LFrame_OutputBaseSaved] + str x30,[sp,#.LFrame_LrSaved] + + ldr x1,[sp,#.LFrame_FilterBaseSaved] + ldr x2,[sp,#.LFrame_OutputBaseSaved] + ldr x8,[sp,#.KO_OutputStride] + ldr x14,[sp,#.KO_OutputCountLeftPad] + ldr x15,[sp,#.KO_OutputCount] + ldr x16,[sp,#.KO_OutputCountRightPad] + ldr x17,[sp,#.KO_Bias] + ldr w19,[sp,#.KO_Flags] + + eor v31.16b,v31.16b,v31.16b + + cmp x5,#1 + b.eq .LFilterCountSupported + cmp x5,#2 + b.eq .LFilterCountSupported + cmp x5,#3 + b.eq .LFilterCountSupportedFC3 + cmp x5,#4 + b.ne .LKernelExit + + cbnz x14,.LKernelExit + cbnz x16,.LKernelExit + cbz x15,.LKernelExit + + mov x20,#0 + + .LNoPadLoopFC4Pass1: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LNoPadTailFC4Pass1 + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + 
CONV_NCHWC_FC2_2OUTPUTS_NOPAD + add x20,x20,#2 + cmp x20,x15 + blo .LNoPadLoopFC4Pass1 + b .LStartFC4Pass2 + + .LNoPadTailFC4Pass1: + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + CONV_NCHWC_FC2_1OUTPUT_NOPAD + add x20,x20,#1 + cmp x20,x15 + blo .LNoPadLoopFC4Pass1 + + .LStartFC4Pass2: + ldr x1,[sp,#.LFrame_FilterBaseSaved] + add x1,x1,x7 + add x1,x1,x7 + ldr x2,[sp,#.LFrame_OutputBaseSaved] + add x2,x2,x8 + add x2,x2,x8 + add x17,x17,#128 + mov x20,#0 + + .LNoPadLoopFC4Pass2: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LNoPadTailFC4Pass2 + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + CONV_NCHWC_FC2_2OUTPUTS_NOPAD + add x20,x20,#2 + cmp x20,x15 + blo .LNoPadLoopFC4Pass2 + b .LKernelExit + + .LNoPadTailFC4Pass2: + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + CONV_NCHWC_FC2_1OUTPUT_NOPAD + add x20,x20,#1 + cmp x20,x15 + blo .LNoPadLoopFC4Pass2 + b .LKernelExit + + .LFilterCountSupportedFC3: + + cbnz x14,.LKernelExit + cbnz x16,.LKernelExit + cbz x15,.LKernelExit + + mov x20,#0 + + .LNoPadLoopFC3Pass1: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LNoPadTailFC3Pass1 + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + CONV_NCHWC_FC2_2OUTPUTS_NOPAD + add x20,x20,#2 + cmp x20,x15 + blo .LNoPadLoopFC3Pass1 + b .LStartFC3Pass2 + + .LNoPadTailFC3Pass1: + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + CONV_NCHWC_FC2_1OUTPUT_NOPAD + add x20,x20,#1 + cmp x20,x15 + blo .LNoPadLoopFC3Pass1 + + .LStartFC3Pass2: + ldr x1,[sp,#.LFrame_FilterBaseSaved] + add x1,x1,x7 + add x1,x1,x7 + ldr x2,[sp,#.LFrame_OutputBaseSaved] + add x2,x2,x8 + add x2,x2,x8 + add x17,x17,#128 + mov x20,#0 + + .LNoPadLoopFC3Pass2: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LNoPadTailFC3Pass2 + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + 
CONV_NCHWC_FC1_2OUTPUTS_NOPAD + add x20,x20,#2 + cmp x20,x15 + blo .LNoPadLoopFC3Pass2 + b .LKernelExit + + .LNoPadTailFC3Pass2: + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + CONV_NCHWC_FC1_1OUTPUT_NOPAD + add x20,x20,#1 + cmp x20,x15 + blo .LNoPadLoopFC3Pass2 + b .LKernelExit + + .LFilterCountSupported: + cbnz x14,.LKernelExit + cbnz x16,.LKernelExit + cbz x15,.LKernelExit + + mov x20,#0 + + .LNoPadLoop1: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LNoPadTail1 + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + cmp x5,#1 + b.eq .LNoPad2FC1 + CONV_NCHWC_FC2_2OUTPUTS_NOPAD + b .LNoPad2Done + + .LNoPad2FC1: + CONV_NCHWC_FC1_2OUTPUTS_NOPAD + + .LNoPad2Done: + add x20,x20,#2 + cmp x20,x15 + blo .LNoPadLoop1 + b .LKernelExit + + .LNoPadTail1: + ldr x21,[sp,#.LFrame_InputBaseSaved] + madd x0,x20,x3,x21 + lsl x23,x20,#6 + add x27,x2,x23 + cmp x5,#1 + b.eq .LNoPad1FC1 + CONV_NCHWC_FC2_1OUTPUT_NOPAD + b .LNoPad1Done + + .LNoPad1FC1: + CONV_NCHWC_FC1_1OUTPUT_NOPAD + + .LNoPad1Done: + add x20,x20,#1 + cmp x20,x15 + blo .LNoPadLoop1 + + .LKernelExit: + RESTORE_NONVOLATILE_AND_RETURN + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 3c0ee29896cd9..4e28ba2bbfa01 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -973,6 +973,8 @@ extern "C" { #if !defined(_WIN32) // AArch64 assembly micro-kernel for direct NCHW convolution MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeonAsm; + // AArch64 assembly micro-kernel for direct NCHWc convolution + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeonAsm; #endif MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; // Intrinsics kernel for depthwise NCHWc convolution diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp index 59c30a80b53af..100f83c809b0e 100644 --- 
a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp @@ -53,6 +53,7 @@ void const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-static_cast(BiasAddition))); const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION))); + const size_t StrideWidthElements = StrideWidth / sizeof(float); const size_t DilationWidthElements = DilationWidth / sizeof(float); const size_t FilterStrideElements = FilterStride / sizeof(float); @@ -187,6 +188,54 @@ void // Implementation of MlasConvNchwFloatKernelNeon // +void + MLASCALL + MlasConvNchwcFloatKernelNeonCpp( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + + void MLASCALL MlasConvNchwFloatKernelNeon( @@ -263,7 +312,118 @@ void unsigned KernelFlags ) { - MlasConvFloatKernelNeonImpl( +#if !defined(_WIN32) + if (FilterCount <= 4) { + const size_t StrideWidthElements = StrideWidth / sizeof(float); + + if (OutputCountLeftPad != 0) { + MlasConvNchwcFloatKernelNeonCpp( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + 0, + 
OutputCountLeftPad, + 0, + Bias, + KernelFlags + ); + } + + if (OutputCount != 0) { + const size_t InteriorOffsetElements = OutputCountLeftPad * StrideWidthElements; + const size_t InteriorOutputOffsetElements = OutputCountLeftPad * BlockSize; + + MlasConvNchwcFloatKernelNeonAsm( + Input + InteriorOffsetElements, + Filter, + Output + InteriorOutputOffsetElements, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + 0, + OutputCount, + 0, + Bias, + KernelFlags + ); + } + + if (OutputCountRightPad != 0) { + const size_t RightOffset = (OutputCountLeftPad + OutputCount) * StrideWidthElements; + const size_t RightOutputOffsetElements = (OutputCountLeftPad + OutputCount) * BlockSize; + + MlasConvNchwcFloatKernelNeonCpp( + Input + RightOffset, + Filter, + Output + RightOutputOffsetElements, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + 0, + OutputCountRightPad, + 0, + Bias, + KernelFlags + ); + } + + return; + } + + if (OutputCountLeftPad == 0 && OutputCountRightPad == 0) { + MlasConvNchwcFloatKernelNeonAsm( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); + return; + } +#endif + + MlasConvNchwcFloatKernelNeonCpp( Input, Filter, Output, diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h index e2654603073ac..e96c2b5a7cec4 100644 --- a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h @@ -21,6 +21,30 @@ Module Name: #if defined(MLAS_USE_ARM_NEON_NCHWC) +void + 
MLASCALL + MlasConvNchwcFloatKernelNeonCpp( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ); + #define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 #define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 #define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 diff --git a/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp b/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp new file mode 100644 index 0000000000000..9bffcf11b0d7b --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_sconv_nchwc.cpp @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "bench_util.h" + +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + +#include "../../../core/mlas/lib/mlasi.h" +#include "../../../core/mlas/lib/sconv_nchwc_kernel_neon.h" + +#include +#include + +namespace { + +constexpr size_t BlockSize = 16; +constexpr unsigned BenchmarkKernelFlags = MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION; + +enum class NchwcKernelBenchPath { + DirectCpp, + WrapperNoPad, + WrapperPadded, +#if !defined(_WIN32) + DirectAsmNoPad, +#endif +}; + +struct NchwcKernelBenchConfig { + size_t filter_count; + size_t output_count_left_pad; + size_t output_count; + size_t output_count_right_pad; + size_t kernel_height; + size_t kernel_width; + unsigned kernel_flags; +}; + +struct NchwcKernelBenchBuffers { + explicit NchwcKernelBenchBuffers(const NchwcKernelBenchConfig& config) + : total_output_count(config.output_count_left_pad + config.output_count + config.output_count_right_pad), + input_width(config.output_count + config.kernel_width - 1), + total_input_width(config.output_count_left_pad + input_width + config.output_count_right_pad), + filter_elements_per_block(config.kernel_height * config.kernel_width * BlockSize * BlockSize), + output_stride_elements(total_output_count * BlockSize), + stride_width_bytes(BlockSize * sizeof(float)), + dilation_width_bytes(BlockSize * sizeof(float)), + input_width_bytes(input_width * BlockSize * sizeof(float)), + dilated_input_width_bytes(total_input_width * BlockSize * sizeof(float)), + input_stride_bytes(dilated_input_width_bytes - config.kernel_width * dilation_width_bytes), + filter_stride_bytes(filter_elements_per_block * sizeof(float)), + output_stride_bytes(output_stride_elements * sizeof(float)), + input_storage(RandomVectorUniform(config.kernel_height * total_input_width * BlockSize, -1.0f, 1.0f)), + filter(RandomVectorUniform(config.filter_count * filter_elements_per_block, -1.0f, 1.0f)), + bias(RandomVectorUniform(config.filter_count * BlockSize, -0.5f, 0.5f)), + 
output(RandomVectorUniform(config.filter_count * output_stride_elements, -0.25f, 0.25f)), + input_base(input_storage.data() + config.output_count_left_pad * BlockSize), + input(input_base - config.output_count_left_pad * BlockSize) {} + + const size_t total_output_count; + const size_t input_width; + const size_t total_input_width; + const size_t filter_elements_per_block; + const size_t output_stride_elements; + const size_t stride_width_bytes; + const size_t dilation_width_bytes; + const size_t input_width_bytes; + const size_t dilated_input_width_bytes; + const size_t input_stride_bytes; + const size_t filter_stride_bytes; + const size_t output_stride_bytes; + std::vector input_storage; + std::vector filter; + std::vector bias; + std::vector output; + const float* input_base; + const float* input; +}; + +void RunNchwcKernelBench( + NchwcKernelBenchPath path, + const NchwcKernelBenchConfig& config, + NchwcKernelBenchBuffers& buffers) { + switch (path) { + case NchwcKernelBenchPath::DirectCpp: + // Args: Input, Filter, Output, StrideWidthBytes, DilationWidthBytes, + // FilterCount, InputStrideBytes, FilterStrideBytes, OutputStrideBytes, + // KernelHeight, KernelWidth, InputBase, InputWidthBytes, + // DilatedInputWidthBytes, OutputCountLeftPad, OutputCount, + // OutputCountRightPad, Bias, KernelFlags. 
+ MlasConvNchwcFloatKernelNeonCpp( + buffers.input, + buffers.filter.data(), + buffers.output.data(), + buffers.stride_width_bytes, + buffers.dilation_width_bytes, + config.filter_count, + buffers.input_stride_bytes, + buffers.filter_stride_bytes, + buffers.output_stride_bytes, + config.kernel_height, + config.kernel_width, + buffers.input_base, + buffers.input_width_bytes, + buffers.dilated_input_width_bytes, + config.output_count_left_pad, + config.output_count, + config.output_count_right_pad, + buffers.bias.data(), + config.kernel_flags); + break; + + case NchwcKernelBenchPath::WrapperNoPad: + case NchwcKernelBenchPath::WrapperPadded: + // Same argument order as the direct C++ kernel above. The wrapper uses + // C++ edges and may route interior spans to asm when supported. + MlasConvNchwcFloatKernelNeon( + buffers.input, + buffers.filter.data(), + buffers.output.data(), + buffers.stride_width_bytes, + buffers.dilation_width_bytes, + config.filter_count, + buffers.input_stride_bytes, + buffers.filter_stride_bytes, + buffers.output_stride_bytes, + config.kernel_height, + config.kernel_width, + buffers.input_base, + buffers.input_width_bytes, + buffers.dilated_input_width_bytes, + config.output_count_left_pad, + config.output_count, + config.output_count_right_pad, + buffers.bias.data(), + config.kernel_flags); + break; + +#if !defined(_WIN32) + case NchwcKernelBenchPath::DirectAsmNoPad: + // Same argument order as the direct C++ kernel above. This entry is for + // the no-padding direct asm microkernel path. 
+      MlasConvNchwcFloatKernelNeonAsm(
+          buffers.input,
+          buffers.filter.data(),
+          buffers.output.data(),
+          buffers.stride_width_bytes,
+          buffers.dilation_width_bytes,
+          config.filter_count,
+          buffers.input_stride_bytes,
+          buffers.filter_stride_bytes,
+          buffers.output_stride_bytes,
+          config.kernel_height,
+          config.kernel_width,
+          buffers.input_base,
+          buffers.input_width_bytes,
+          buffers.dilated_input_width_bytes,
+          config.output_count_left_pad,
+          config.output_count,
+          config.output_count_right_pad,
+          buffers.bias.data(),
+          config.kernel_flags);
+      break;
+#endif
+  }
+}
+
+void NCHWC_KERNEL_ROW(
+    benchmark::State& state,
+    NchwcKernelBenchPath path,
+    size_t filter_count,
+    size_t output_count_left_pad,
+    size_t output_count,
+    size_t output_count_right_pad,
+    size_t kernel_height,
+    size_t kernel_width) {
+  if (MlasNchwcGetBlockSize() != BlockSize) {
+    state.SkipWithError("Unexpected NCHWC block size for ARM NEON benchmark");
+    return;
+  }
+
+  const NchwcKernelBenchConfig config{
+      filter_count,
+      output_count_left_pad,
+      output_count,
+      output_count_right_pad,
+      kernel_height,
+      kernel_width,
+      BenchmarkKernelFlags};
+
+  NchwcKernelBenchBuffers buffers(config);
+
+  RunNchwcKernelBench(path, config, buffers);
+
+  for (auto _ : state) {
+    RunNchwcKernelBench(path, config, buffers);
+    benchmark::DoNotOptimize(buffers.output.data());
+    benchmark::ClobberMemory();
+  }
+
+  const int64_t work_items = static_cast<int64_t>(config.filter_count * buffers.total_output_count * BlockSize);
+  state.SetItemsProcessed(state.iterations() * work_items);
+}
+
+}  // namespace
+
+// BENCHMARK_CAPTURE args after the path are:
+// filter_count, output_count_left_pad, output_count,
+// output_count_right_pad, kernel_height, kernel_width
+BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC1_DirectCpp_NoPad, NchwcKernelBenchPath::DirectCpp, 1, 0, 56, 0, 3, 3)->UseRealTime();
+BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC1_Wrapper_NoPad, NchwcKernelBenchPath::WrapperNoPad, 1, 0, 56, 0, 3, 3)->UseRealTime();
+BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC1_Wrapper_Padded, NchwcKernelBenchPath::WrapperPadded, 1, 1, 54, 1, 3, 3)->UseRealTime(); + +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC2_DirectCpp_NoPad, NchwcKernelBenchPath::DirectCpp, 2, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC2_Wrapper_NoPad, NchwcKernelBenchPath::WrapperNoPad, 2, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC2_Wrapper_Padded, NchwcKernelBenchPath::WrapperPadded, 2, 1, 54, 1, 3, 3)->UseRealTime(); + +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC3_DirectCpp_NoPad, NchwcKernelBenchPath::DirectCpp, 3, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC3_Wrapper_NoPad, NchwcKernelBenchPath::WrapperNoPad, 3, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC3_Wrapper_Padded, NchwcKernelBenchPath::WrapperPadded, 3, 1, 54, 1, 3, 3)->UseRealTime(); + +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC4_DirectCpp_NoPad, NchwcKernelBenchPath::DirectCpp, 4, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC4_Wrapper_NoPad, NchwcKernelBenchPath::WrapperNoPad, 4, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC4_Wrapper_Padded, NchwcKernelBenchPath::WrapperPadded, 4, 1, 54, 1, 3, 3)->UseRealTime(); + +#if !defined(_WIN32) +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC1_DirectAsm_NoPad, NchwcKernelBenchPath::DirectAsmNoPad, 1, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC2_DirectAsm_NoPad, NchwcKernelBenchPath::DirectAsmNoPad, 2, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC3_DirectAsm_NoPad, NchwcKernelBenchPath::DirectAsmNoPad, 3, 0, 56, 0, 3, 3)->UseRealTime(); +BENCHMARK_CAPTURE(NCHWC_KERNEL_ROW, FC4_DirectAsm_NoPad, NchwcKernelBenchPath::DirectAsmNoPad, 4, 0, 56, 0, 3, 3)->UseRealTime(); +#endif + +#endif diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc_kernel.cpp b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc_kernel.cpp new file mode 
100644
index 0000000000000..6c541b5c71d8c
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc_kernel.cpp
@@ -0,0 +1,539 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "test_util.h"
+
+#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
+
+#include <algorithm>
+#include <string>
+
+#include "../../../core/mlas/lib/mlasi.h"
+#include "../../../core/mlas/lib/sconv_nchwc_kernel_neon.h"
+
+class MlasNchwcConvKernelTest : public MlasTestBase {
+ private:
+  static constexpr size_t BlockSize = 16;
+
+  MatrixGuardBuffer<float> BufferInput;
+  MatrixGuardBuffer<float> BufferFilter;
+  MatrixGuardBuffer<float> BufferBias;
+  MatrixGuardBuffer<float> BufferOutput;
+  MatrixGuardBuffer<float> BufferOutputCpp;
+  MatrixGuardBuffer<float> BufferOutputAsm;
+  MatrixGuardBuffer<float> BufferOutputReference;
+
+  template <typename FillFunc>
+  static void FillBuffer(float* buffer, size_t count, FillFunc&& fill_func) {
+    for (size_t i = 0; i < count; ++i) {
+      buffer[i] = fill_func(i);
+    }
+  }
+
+  static void ReferenceKernel(const float* Input,
+                              const float* Filter,
+                              float* Output,
+                              size_t StrideWidth,
+                              size_t DilationWidth,
+                              size_t KernelHeight,
+                              size_t KernelWidth,
+                              const float* InputBase,
+                              size_t InputWidth,
+                              size_t DilatedInputWidth,
+                              size_t FilterCount,
+                              size_t OutputStride,
+                              size_t OutputCountLeftPad,
+                              size_t OutputCount,
+                              size_t OutputCountRightPad,
+                              const float* Bias,
+                              unsigned KernelFlags) {
+    const size_t stride_width_elements = StrideWidth / sizeof(float);
+    const size_t dilation_width_elements = DilationWidth / sizeof(float);
+    const size_t input_width_elements = InputWidth / sizeof(float);
+    const size_t dilated_input_width_elements = DilatedInputWidth / sizeof(float);
+    const size_t output_stride_elements = OutputStride / sizeof(float);
+    const size_t filter_stride_elements = KernelHeight * KernelWidth * BlockSize * BlockSize;
+    const size_t total_output_count = OutputCountLeftPad + OutputCount + OutputCountRightPad;
+
+    for (size_t filter_set = 0;
filter_set < FilterCount; ++filter_set) { + const float* filter_block = Filter + filter_set * filter_stride_elements; + float* output_block_base = Output + filter_set * output_stride_elements; + const float* bias_block = Bias + filter_set * BlockSize; + + for (size_t output_idx = 0; output_idx < total_output_count; ++output_idx) { + float accumulator[BlockSize]{}; + + for (size_t kh = 0; kh < KernelHeight; ++kh) { + for (size_t kw = 0; kw < KernelWidth; ++kw) { + const float* input_base = Input + output_idx * stride_width_elements + + kh * dilated_input_width_elements + + kw * dilation_width_elements; + const size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) + + kw * (BlockSize * BlockSize); + const float* input_row_start = InputBase + kh * dilated_input_width_elements; + const float* input_row_end = input_row_start + input_width_elements; + + for (size_t input_lane = 0; input_lane < BlockSize; ++input_lane) { + const float* input_element = input_base + input_lane; + const float input_value = (input_element >= input_row_start && input_element < input_row_end) + ? 
*input_element + : 0.0f; + const float* filter_row = filter_block + kernel_base_pos + input_lane * BlockSize; + + for (size_t output_lane = 0; output_lane < BlockSize; ++output_lane) { + accumulator[output_lane] += input_value * filter_row[output_lane]; + } + } + } + } + + float* output_block = output_block_base + output_idx * BlockSize; + + for (size_t output_lane = 0; output_lane < BlockSize; ++output_lane) { + float value = accumulator[output_lane]; + + if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0) { + value += output_block[output_lane]; + } + if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0) { + value += bias_block[output_lane]; + } + if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0) { + value = std::max(value, 0.0f); + } + + output_block[output_lane] = value; + } + } + } + } + + void AssertClose(const float* actual, + const float* expected, + size_t count, + const char* actual_label, + const char* expected_label, + size_t FilterCount, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + size_t KernelHeight, + size_t KernelWidth, + unsigned KernelFlags) { + for (size_t i = 0; i < count; ++i) { + ASSERT_TRUE(CloseEnough(actual[i], expected[i])) + << actual_label << " vs " << expected_label + << " @" << i + << " got=" << actual[i] + << " expected=" << expected[i] + << " FilterCount=" << FilterCount + << " LeftPad=" << OutputCountLeftPad + << " OutputCount=" << OutputCount + << " RightPad=" << OutputCountRightPad + << "/KH=" << KernelHeight + << "/KW=" << KernelWidth + << "/Flags=" << KernelFlags; + } + } + + void TestKernel(size_t OutputCount, + size_t KernelHeight, + size_t KernelWidth, + unsigned KernelFlags, + size_t FilterCount = 1, + size_t OutputCountLeftPad = 0, + size_t OutputCountRightPad = 0) { + std::fprintf(stderr, + "Start case FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + 
OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + + const size_t InputWidth = OutputCount + KernelWidth - 1; + const size_t TotalInputWidth = OutputCountLeftPad + InputWidth + OutputCountRightPad; + const size_t InputElements = KernelHeight * TotalInputWidth * BlockSize; + const size_t FilterElementsPerBlock = KernelHeight * KernelWidth * BlockSize * BlockSize; + const size_t FilterElements = FilterCount * FilterElementsPerBlock; + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + const size_t OutputStrideElements = TotalOutputCount * BlockSize; + const size_t OutputElements = FilterCount * OutputStrideElements; + const size_t BiasElements = FilterCount * BlockSize; + + float* InputStorage = BufferInput.GetFilledBuffer(InputElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i * 7 % 23) - 11)) / 8.0f; + }); + }); + float* Filter = BufferFilter.GetFilledBuffer(FilterElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i * 5 % 29) - 14)) / 9.0f; + }); + }); + float* Bias = BufferBias.GetFilledBuffer(BiasElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i % 9) - 4)) / 8.0f; + }); + }); + float* Output = BufferOutput.GetFilledBuffer(OutputElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i * 3 % 17) - 8)) / 7.0f; + }); + }); + float* OutputCpp = BufferOutputCpp.GetFilledBuffer(OutputElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i * 3 % 17) - 8)) / 7.0f; + }); + }); + float* OutputAsm = BufferOutputAsm.GetFilledBuffer(OutputElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i * 3 % 17) - 8)) / 7.0f; + }); + }); + float* OutputReference = 
BufferOutputReference.GetFilledBuffer(OutputElements, [](float* start, size_t count) { + FillBuffer(start, count, [](size_t i) { + return float((int(i * 3 % 17) - 8)) / 7.0f; + }); + }); + + const size_t StrideWidthBytes = BlockSize * sizeof(float); + const size_t DilationWidthBytes = BlockSize * sizeof(float); + const size_t StrideWidthElements = StrideWidthBytes / sizeof(float); + const size_t InputWidthBytes = BlockSize * InputWidth * sizeof(float); + const size_t DilatedInputWidthBytes = BlockSize * TotalInputWidth * sizeof(float); + const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; + const size_t FilterStrideBytes = FilterElementsPerBlock * sizeof(float); + const size_t OutputStrideBytes = OutputStrideElements * sizeof(float); + const float* InputBase = InputStorage + OutputCountLeftPad * StrideWidthElements; + const float* Input = InputBase - OutputCountLeftPad * StrideWidthElements; + + ReferenceKernel(Input, + Filter, + OutputReference, + StrideWidthBytes, + DilationWidthBytes, + KernelHeight, + KernelWidth, + InputBase, + InputWidthBytes, + DilatedInputWidthBytes, + FilterCount, + OutputStrideBytes, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags); + + std::fprintf(stderr, + "Completed reference FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + + MlasConvNchwcFloatKernelNeonCpp(Input, + Filter, + OutputCpp, + StrideWidthBytes, + DilationWidthBytes, + FilterCount, + InputStrideBytes, + FilterStrideBytes, + OutputStrideBytes, + KernelHeight, + KernelWidth, + InputBase, + InputWidthBytes, + DilatedInputWidthBytes, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags); + + std::fprintf(stderr, + "Completed cpp 
FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + + MlasConvNchwcFloatKernelNeon(Input, + Filter, + Output, + StrideWidthBytes, + DilationWidthBytes, + FilterCount, + InputStrideBytes, + FilterStrideBytes, + OutputStrideBytes, + KernelHeight, + KernelWidth, + InputBase, + InputWidthBytes, + DilatedInputWidthBytes, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags); + + std::fprintf(stderr, + "Completed wrapper FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + +#if !defined(_WIN32) + if (OutputCountLeftPad == 0 && OutputCountRightPad == 0) { + std::fprintf(stderr, + "Calling asm FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + + MlasConvNchwcFloatKernelNeonAsm(Input, + Filter, + OutputAsm, + StrideWidthBytes, + DilationWidthBytes, + FilterCount, + InputStrideBytes, + FilterStrideBytes, + OutputStrideBytes, + KernelHeight, + KernelWidth, + InputBase, + InputWidthBytes, + DilatedInputWidthBytes, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags); + + std::fprintf(stderr, + "Completed asm FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + } else { + std::fprintf(stderr, + "Skipping direct asm FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + 
OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + } +#endif + + AssertClose(OutputCpp, OutputReference, OutputElements, "cpp", "reference", FilterCount, OutputCountLeftPad, OutputCount, OutputCountRightPad, KernelHeight, KernelWidth, KernelFlags); + std::fprintf(stderr, + "Completed cpp-vs-reference FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + AssertClose(Output, OutputCpp, OutputElements, "wrapper", "cpp", FilterCount, OutputCountLeftPad, OutputCount, OutputCountRightPad, KernelHeight, KernelWidth, KernelFlags); + std::fprintf(stderr, + "Completed wrapper-vs-cpp FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); +#if !defined(_WIN32) + if (OutputCountLeftPad == 0 && OutputCountRightPad == 0) { + AssertClose(OutputAsm, OutputCpp, OutputElements, "asm", "cpp", FilterCount, OutputCountLeftPad, OutputCount, OutputCountRightPad, KernelHeight, KernelWidth, KernelFlags); + std::fprintf(stderr, + "Completed asm-vs-cpp FilterCount=%zu/LeftPad=%zu/OutputCount=%zu/RightPad=%zu/KH=%zu/KW=%zu/Flags=%u\n", + FilterCount, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + KernelHeight, + KernelWidth, + KernelFlags); + std::fflush(stderr); + } +#endif + } + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Conv2dNchwcKernel"); + return suite_name.c_str(); + } + + void ExecuteShort() override { + if (MlasNchwcGetBlockSize() != BlockSize) { + return; + } + + // TestKernel(OutputCount, KernelHeight, KernelWidth, KernelFlags, FilterCount) + + // Single-output microkernel coverage. 
+ TestKernel(1, 1, 1, 0); + TestKernel(1, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION); + + // Two-output fast path with and without bias. + TestKernel(2, 1, 1, 0); + TestKernel(2, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION); + + // Single-output multi-row and multi-column coverage. + TestKernel(1, 3, 3, 0); + TestKernel(1, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION); + + // Two-output fast path on a larger spatial kernel. + TestKernel(2, 3, 3, 0); + TestKernel(2, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION); + + // Two-output postprocess coverage: accumulate only and accumulate+bias+ReLU. + TestKernel(2, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT); + TestKernel(2, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION); + + // Three outputs exercise the two-output fast path followed by the one-output tail. + TestKernel(3, 3, 3, 0); + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION); + + // Three outputs on a 1x3 kernel with full postprocess coverage. + TestKernel(3, 1, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION); + + // FC2 single-output coverage. + TestKernel(1, 1, 1, 0, 2); + TestKernel(1, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 2); + TestKernel(1, 3, 3, 0, 2); + TestKernel(1, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 2); + + // FC2 two-output fast path. + TestKernel(2, 1, 1, 0, 2); + TestKernel(2, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 2); + TestKernel(2, 3, 3, 0, 2); + TestKernel(2, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, 2); + + // FC2 tail coverage: two-output fast path followed by one-output tail. + TestKernel(3, 3, 3, 0, 2); + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 2); + + // FC3 single-output coverage. 
+ TestKernel(1, 1, 1, 0, 3); + TestKernel(1, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 3); + TestKernel(1, 3, 3, 0, 3); + TestKernel(1, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 3); + + // FC3 multi-output and postprocess coverage. + TestKernel(2, 1, 1, 0, 3); + TestKernel(2, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 3); + TestKernel(2, 3, 3, 0, 3); + TestKernel(2, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 3); + + // FC3 tail coverage. + TestKernel(3, 3, 3, 0, 3); + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 3); + + // FC4 single-output coverage. + TestKernel(1, 1, 1, 0, 4); + TestKernel(1, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 4); + TestKernel(1, 3, 3, 0, 4); + TestKernel(1, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 4); + + // FC4 multi-output and postprocess coverage. + TestKernel(2, 1, 1, 0, 4); + TestKernel(2, 1, 1, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 4); + TestKernel(2, 3, 3, 0, 4); + TestKernel(2, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 4); + + // FC4 tail coverage. + TestKernel(3, 3, 3, 0, 4); + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 4); + + // Padded wrapper coverage: C++ edges with asm interior. + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 1, 1, 1); + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 2, + 1, + 1); + + // Asymmetric padded wrapper coverage. 
+ TestKernel(4, 3, 3, 0, 1, 1, 0); + TestKernel(4, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 1, 0, 1); + TestKernel(4, 1, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, + 2, + 1, + 0); + TestKernel(4, 1, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 2, + 0, + 1); + + // Wider padded rows make the interior asm span non-trivial. + TestKernel(8, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 1, 1, 1); + TestKernel(8, 3, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 2, + 1, + 1); + + // FC3 padded coverage uses C++ edges with asm on the interior span. + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 3, 1, 1); + TestKernel(4, 1, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 3, + 1, + 0); + + // FC4 padded coverage uses C++ edges with asm on the interior span. + TestKernel(3, 3, 3, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION, 4, 1, 1); + TestKernel(4, 1, 3, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT | MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION | MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION, + 4, + 1, + 0); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + return (MlasNchwcGetBlockSize() > 1 && is_short_execute) + ? 
MlasDirectShortExecuteTests<MlasNchwcConvKernelTest>::RegisterShortExecute()
+             : 0;
+});
+
+#endif
diff --git a/onnxruntime/test/util/file_util.cc b/onnxruntime/test/util/file_util.cc
index bb7d3a472477d..6441794258af6 100644
--- a/onnxruntime/test/util/file_util.cc
+++ b/onnxruntime/test/util/file_util.cc
@@ -1,12 +1,13 @@
 #include "file_util.h"
 
-#include <gtest/gtest.h>
 #include <fcntl.h>
 
 #ifdef _WIN32
 #include <Windows.h>
 #include <io.h>
 #include <share.h>
+#else
+#include <unistd.h>
 #endif
 
 namespace onnxruntime {
@@ -26,9 +27,9 @@ PathString GetSharedLibraryFileName(const PathString& base_library_name) {
 
 void DeleteFileFromDisk(const ORTCHAR_T* path) {
 #ifdef _WIN32
-  ASSERT_EQ(TRUE, DeleteFileW(path));
+  ORT_ENFORCE(DeleteFileW(path) == TRUE, "DeleteFileW failed for path.");
 #else
-  ASSERT_EQ(0, unlink(path));
+  ORT_ENFORCE(unlink(path) == 0, "unlink failed for path.");
 #endif
 }
 void CreateTestFile(int& out, std::basic_string<ORTCHAR_T>& filename_template) {
@@ -37,7 +38,7 @@ void CreateTestFile(int& out, std::basic_string<ORTCHAR_T>& filename_template) {
   ORTCHAR_T* filename = const_cast<ORTCHAR_T*>(filename_template.c_str());
 #ifdef _WIN32
-  ASSERT_EQ(0, _wmktemp_s(filename, filename_template.length() + 1));
+  ORT_ENFORCE(_wmktemp_s(filename, filename_template.length() + 1) == 0, "_wmktemp_s failed.");
   int fd;
   int err = _wsopen_s(&fd, filename, _O_CREAT | _O_EXCL | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYRW, _S_IREAD | _S_IWRITE);
   if (err != 0)
@@ -56,9 +57,9 @@ void CreateTestFile(FILE*& out, std::basic_string<ORTCHAR_T>& filename_template)
   ORTCHAR_T* filename = const_cast<ORTCHAR_T*>(filename_template.c_str());
 #ifdef _WIN32
-  ASSERT_EQ(0, _wmktemp_s(filename, filename_template.length() + 1));
+  ORT_ENFORCE(_wmktemp_s(filename, filename_template.length() + 1) == 0, "_wmktemp_s failed.");
   FILE* fp = nullptr;
-  ASSERT_EQ(0, _wfopen_s(&fp, filename, ORT_TSTR("wb")));
+  ORT_ENFORCE(_wfopen_s(&fp, filename, ORT_TSTR("wb")) == 0, "_wfopen_s failed.");
 #else
   int fd = mkstemp(filename);
   if (fd < 0) {
diff --git a/onnxruntime/test/util/temp_dir.cc b/onnxruntime/test/util/temp_dir.cc
index 
cbd0bdc7febd6..2ee8a9aaa4ed7 100644 --- a/onnxruntime/test/util/temp_dir.cc +++ b/onnxruntime/test/util/temp_dir.cc @@ -3,8 +3,6 @@ #include "test/util/include/temp_dir.h" -#include "gtest/gtest.h" - #include "core/platform/env.h" namespace onnxruntime { @@ -12,10 +10,8 @@ namespace test { namespace { void CreateOrDeleteDirectory(const PathString& path, bool create, bool throw_on_fail = true) { const auto status = create ? Env::Default().CreateFolder(path) : Env::Default().DeleteFolder(path); - EXPECT_TRUE(status.IsOK()) << "Failed to " << (create ? "create" : "delete") << "temporary directory " << path; - if (throw_on_fail) { - ORT_ENFORCE(status.IsOK()); + ORT_ENFORCE(status.IsOK(), "Failed to ", (create ? "create" : "delete"), " temporary directory."); } } } // namespace @@ -26,8 +22,7 @@ TemporaryDirectory::TemporaryDirectory(const PathString& path, bool delete_if_ex const bool exists = Env::Default().FolderExists(path_); if (exists) { if (!delete_if_exists) { - EXPECT_FALSE(exists) << "Temporary directory " << path_ << " already exists."; - ORT_ENFORCE(!exists); + ORT_THROW("Temporary directory already exists."); } CreateOrDeleteDirectory(path_, /* create */ false);