
mlas/arm64: add BF16 fast-math conv kernels for NCHW/NCHWc paths#27878

Merged
hariharans29 merged 11 commits into microsoft:main from milpuz01:conv_fast_math
Apr 9, 2026
Conversation

@milpuz01
Contributor

Description

Add Arm64 BF16 fast-math convolution support in MLAS:

  • direct NCHW conv
  • depthwise 3x3 NCHWc conv
  • pointwise 1x1 NCHWc conv

This change adds new AArch64 BF16 asm kernels, wires them into MLAS platform dispatch, keeps accumulated pointwise batches on the custom BF16 path instead of falling back to generic SBGEMM, and adds the required BF16 build flags.

The new paths are only used when Arm64 BF16 fast-math is enabled via the existing session option. Baseline FP32 behavior is unchanged.

Performance

Individual convolution improvements measured on an AWS c8g instance. In the columns below, base is FP32 execution, fast-math is fast-math enabled without this PR, and PR is fast-math with this change:

| Type | Shape | fast-math vs base | PR w/ fast-math vs base | PR w/ fast-math vs fast-math |
| --- | --- | --- | --- | --- |
| depthwise | N1 IC32 OC32 H112xW112->112x112 K3x3 S1x1 D1x1 P1/1/1/1 G32 | 0.991x | 1.047x | 1.057x |
| depthwise | N1 IC96 OC96 H112xW112->56x56 K3x3 S2x2 D1x1 P1/1/1/1 G96 | 1.015x | 1.015x | 1.000x |
| depthwise | N1 IC144 OC144 H56xW56->28x28 K3x3 S2x2 D1x1 P1/1/1/1 G144 | 1.020x | 1.004x | 0.984x |
| depthwise | N1 IC144 OC144 H56xW56->56x56 K3x3 S1x1 D1x1 P1/1/1/1 G144 | 1.034x | 1.138x | 1.101x |
| depthwise | N1 IC192 OC192 H28xW28->28x28 K3x3 S1x1 D1x1 P1/1/1/1 G192 | 0.997x | 1.033x | 1.037x |
| depthwise | N1 IC384 OC384 H28xW28->14x14 K3x3 S2x2 D1x1 P1/1/1/1 G384 | 1.016x | 1.021x | 1.005x |
| depthwise | N1 IC384 OC384 H28xW28->28x28 K3x3 S1x1 D1x1 P1/1/1/1 G384 | 1.011x | 1.090x | 1.077x |
| depthwise | N1 IC576 OC576 H14xW14->7x7 K3x3 S2x2 D1x1 P1/1/1/1 G576 | 1.029x | 0.995x | 0.967x |
| depthwise | N1 IC576 OC576 H14xW14->14x14 K3x3 S1x1 D1x1 P1/1/1/1 G576 | 1.025x | 1.006x | 0.982x |
| depthwise | N1 IC960 OC960 H7xW7->7x7 K3x3 S1x1 D1x1 P1/1/1/1 G960 | 1.002x | 0.941x | 0.939x |
| nchw | N1 IC3 OC32 H224xW224->112x112 K3x3 S2x2 D1x1 P1/1/1/1 G1 | 1.001x | 1.058x | 1.058x |
| pointwise | N1 IC16 OC96 H112xW112->112x112 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.213x | 1.328x | 1.095x |
| pointwise | N1 IC32 OC16 H112xW112->112x112 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.020x | 1.019x | 0.998x |
| pointwise | N1 IC32 OC32 H112xW112->112x112 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.118x | 1.196x | 1.069x |
| pointwise | N1 IC32 OC144 H56xW56->56x56 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.220x | 1.528x | 1.252x |
| pointwise | N1 IC32 OC192 H28xW28->28x28 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.199x | 1.418x | 1.183x |
| pointwise | N1 IC64 OC384 H28xW28->28x28 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.294x | 1.938x | 1.497x |
| pointwise | N1 IC96 OC32 H56xW56->56x56 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.080x | 1.426x | 1.320x |
| pointwise | N1 IC96 OC576 H14xW14->14x14 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.280x | 1.961x | 1.532x |
| pointwise | N1 IC144 OC32 H28xW28->28x28 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.132x | 1.351x | 1.193x |
| pointwise | N1 IC144 OC32 H56xW56->56x56 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.073x | 1.374x | 1.281x |
| pointwise | N1 IC160 OC960 H7xW7->7x7 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.133x | 1.744x | 1.539x |
| pointwise | N1 IC192 OC32 H28xW28->28x28 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.166x | 1.411x | 1.210x |
| pointwise | N1 IC192 OC64 H28xW28->28x28 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.212x | 1.763x | 1.454x |
| pointwise | N1 IC320 OC1280 H7xW7->7x7 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.136x | 2.059x | 1.812x |
| pointwise | N1 IC384 OC64 H28xW28->28x28 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.256x | 1.904x | 1.516x |
| pointwise | N1 IC384 OC96 H14xW14->14x14 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.206x | 1.929x | 1.600x |
| pointwise | N1 IC576 OC96 H14xW14->14x14 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.250x | 2.055x | 1.644x |
| pointwise | N1 IC576 OC160 H7xW7->7x7 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 0.902x | 1.423x | 1.577x |
| pointwise | N1 IC960 OC160 H7xW7->7x7 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 0.915x | 1.527x | 1.668x |
| pointwise | N1 IC960 OC320 H7xW7->7x7 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 1.020x | 1.756x | 1.723x |
| pointwise | N1 IC1280 OC1008 H1xW1->1x1 K1x1 S1x1 D1x1 P0/0/0/0 G1 | 0.747x | 1.149x | 1.538x |

Full-model improvements for MobileNet v2.7 with 8 threads on c8g (AWS Graviton 4) and Standard_D32plds_v6 (Azure Cobalt 100):

| Instance | PR w/ fast-math vs base | PR w/ fast-math vs fast-math |
| --- | --- | --- |
| c8g | 1.892x | 1.647x |
| Standard_D32plds_v6 | 2.884x | 1.692x |

(cc: @Rohanjames1997 @snadampal)

Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@hariharans29 hariharans29 requested a review from Copilot March 27, 2026 18:48
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

This PR adds Arm64 BF16 fast-math convolution support to MLAS by introducing new AArch64 BF16 assembly kernels and wiring them into the MLAS platform dispatch so BF16 fast-math can accelerate NCHW, depthwise 3x3 NCHWc, and pointwise 1x1 NCHWc convolution paths on Linux/AArch64.

Changes:

  • Add new AArch64 BF16 asm micro-kernels for direct NCHW conv, depthwise 3x3 NCHWc conv, and pointwise 1x1 NCHWc conv (BFMMLA-based).
  • Route MLAS conv dispatch to BF16 kernels when WorkBlock->UseBf16 is enabled (Linux/AArch64), including keeping pointwise accumulation on the BF16 path.
  • Update CMake to compile the new asm sources (and BF16-dependent C++ source) with -march=armv8.2-a+bf16.
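The build-flag wiring in the last bullet could be sketched roughly as follows. This is an illustrative CMake fragment, not the exact change from the PR; `MLAS_SRC_DIR` and the guarding condition are assumptions based on how `cmake/onnxruntime_mlas.cmake` typically scopes per-source flags:

```cmake
# Sketch: apply the BF16 arch flags only to the new BF16 conv sources,
# and only when the toolchain actually supports BF16 (see the later macOS fix).
if(onnxruntime_target_platform STREQUAL "aarch64" AND NOT APPLE)
  set_source_files_properties(
    ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S
    ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S
    ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S
    ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
    PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+bf16"
  )
endif()
```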

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/mlas/lib/snchwc.cpp | Select BF16 kernels for NCHW + depthwise paths when UseBf16 is set (Linux/AArch64). |
| onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp | Implements BF16 NCHW kernel wrapper, BF16 depthwise wrapper + asm hot path dispatch, and BF16 pointwise asm path + post-processing helpers. |
| onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp | Tries the new BF16 pointwise asm path first before falling back to SBGEMM. |
| onnxruntime/core/mlas/lib/platform.cpp | Wires new BF16 conv function pointers into MLAS_PLATFORM on Linux/AArch64. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds declarations and platform struct members for new BF16 conv kernels. |
| onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeonBf16.S | New BF16 pointwise (1x1) NCHWc asm kernels + filter packing. |
| onnxruntime/core/mlas/lib/aarch64/SconvKernelNeonBf16.S | New BF16 direct NCHW conv asm kernel + filter packing + BF16 epilogue helper. |
| onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeonBf16.S | New BF16 depthwise 3x3 NCHWc asm kernels + dispatch/epilogue hook. |
| cmake/onnxruntime_mlas.cmake | Adds new asm sources and applies BF16 -march flags. |
Comments suppressed due to low confidence (1)

onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp:834

  • BiasMask/ReluMask are computed as -(KernelFlags & FLAG) even though the flags are 0x2/0x4. This yields -2/-4 masks and will corrupt outputs when applying bias and/or ReLU via bitwise operations. Please build these masks from boolean predicates (0/-1) before using MlasAnd*/MlasBlend*.
    const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
    const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)));
    const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)));
    const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));


@aviralagrawal

Thank you for your contribution! Excited to see this upstreamed 🚀

Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@milpuz01
Contributor Author

The commit 5e0315a should hopefully fix the macOS failures from the previous run.

@hariharans29 hariharans29 requested a review from Copilot March 31, 2026 02:28
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.



Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@milpuz01
Contributor Author

The previous macOS failure happened because the BF16 conv asm files were always added in onnxruntime_mlas.cmake, while the BF16 compile flags were only applied under if(NOT APPLE). As a result, Apple Clang tried to assemble bfcvtn without BF16 enabled. The change in 68108bd makes the BF16 conv asm conditional on actual toolchain support.

@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@milpuz01
Contributor Author

The commit 8a69461 fixes an error on Linux/Arm64 builds where the compiler targets AArch64 but NCHWC Arm64 BF16 conv support is not enabled.

Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.



milpuz01 added 2 commits April 7, 2026 13:38
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.



Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@hariharans29
Member

Thanks. LGTM overall. Will just seek Copilot's opinion one more time.

@hariharans29 hariharans29 requested a review from Copilot April 9, 2026 18:00
@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated no new comments.


