mlas/arm64: add BF16 fast-math conv kernels for NCHW/NCHWc paths #27878
hariharans29 merged 11 commits into microsoft:main
Conversation
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
This PR adds Arm64 BF16 fast-math convolution support to MLAS by introducing new AArch64 BF16 assembly kernels and wiring them into the MLAS platform dispatch so BF16 fast-math can accelerate NCHW, depthwise 3x3 NCHWc, and pointwise 1x1 NCHWc convolution paths on Linux/AArch64.
Changes:
- Add new AArch64 BF16 asm micro-kernels for direct NCHW conv, depthwise 3x3 NCHWc conv, and pointwise 1x1 NCHWc conv (BFMMLA-based).
- Route MLAS conv dispatch to BF16 kernels when `WorkBlock->UseBf16` is enabled (Linux/AArch64), including keeping pointwise accumulation on the BF16 path (see the dispatch sketch after this list).
- Update CMake to compile the new asm sources (and the BF16-dependent C++ source) with `-march=armv8.2-a+bf16`.
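To make the wiring concrete, here is a minimal, self-contained sketch of the gating described above. The names `ConvWorkBlock`, `SelectNchwKernel`, and `ConvNchwFloatKernelBf16` are stand-ins, not the actual MLAS identifiers; the real selection lives in snchwc.cpp behind Linux/AArch64 guards.

```cpp
#include <cstdio>

// Stand-in kernel entry points; the real ones are asm micro-kernels.
using ConvKernelFn = void (*)();
static void ConvNchwFloatKernel() { std::puts("FP32 NCHW kernel"); }
static void ConvNchwFloatKernelBf16() { std::puts("BF16 NCHW kernel"); }  // hypothetical name

struct ConvWorkBlock {
    bool UseBf16;  // set when Arm64 BF16 fast-math is enabled via the session option
};

static ConvKernelFn SelectNchwKernel(const ConvWorkBlock* WorkBlock) {
#if defined(__aarch64__) && defined(__linux__)
    if (WorkBlock->UseBf16) {
        return ConvNchwFloatKernelBf16;  // BFMMLA-based fast-math path
    }
#endif
    return ConvNchwFloatKernel;  // baseline FP32 path, unchanged by this PR
}

int main() {
    ConvWorkBlock wb{true};
    SelectNchwKernel(&wb)();  // on Linux/AArch64 this selects the BF16 kernel
    return 0;
}
```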
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/snchwc.cpp | Select BF16 kernels for NCHW + depthwise paths when UseBf16 is set (Linux/AArch64). |
| onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp | Implements BF16 NCHW kernel wrapper, BF16 depthwise wrapper + asm hot path dispatch, and BF16 pointwise asm path + post-processing helpers. |
| onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp | Tries the new BF16 pointwise asm path first before falling back to SBGEMM. |
| onnxruntime/core/mlas/lib/platform.cpp | Wires new BF16 conv function pointers into MLAS_PLATFORM on Linux/AArch64. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds declarations and platform struct members for new BF16 conv kernels. |
| onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeonBf16.S | New BF16 pointwise (1x1) NCHWc asm kernels + filter packing. |
| onnxruntime/core/mlas/lib/aarch64/SconvKernelNeonBf16.S | New BF16 direct NCHW conv asm kernel + filter packing + BF16 epilogue helper. |
| onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeonBf16.S | New BF16 depthwise 3x3 NCHWc asm kernels + dispatch/epilogue hook. |
| cmake/onnxruntime_mlas.cmake | Adds new asm sources and applies BF16 -march flags. |
Comments suppressed due to low confidence (1)
onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp:834
BiasMask/ReluMask are computed as `-(KernelFlags & FLAG)` even though the flags are 0x2/0x4. This yields -2/-4 masks and will corrupt outputs when applying bias and/or ReLU via bitwise operations. Please build these masks from boolean predicates (0/-1) before using MlasAnd*/MlasBlend*.
const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
const float32x4_t AccumulateMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)));
const float32x4_t BiasMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)));
const float32x4_t ReluMask = vreinterpretq_f32_s32(MlasBroadcastInt32x4(-(KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)));
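For illustration, a hedged sketch of the kind of fix the comment asks for, using the plain NEON intrinsic `vdupq_n_s32` in place of the `MlasBroadcastInt32x4` wrapper; the flag values follow the comment (0x2/0x4) and the helper name `MakeLaneMask` is hypothetical:

```cpp
#include <arm_neon.h>
#include <cstdint>

// Flag values as stated in the review comment.
constexpr unsigned MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION = 0x2;
constexpr unsigned MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION = 0x4;

// Build a per-lane 0x00000000 / 0xFFFFFFFF mask from a boolean predicate
// instead of negating the raw flag bits, so bitwise AND/blend operations
// select whole lanes rather than producing -2/-4 garbage masks.
static float32x4_t MakeLaneMask(unsigned KernelFlags, unsigned Flag) {
    const int32_t Set = (KernelFlags & Flag) != 0;          // 0 or 1
    return vreinterpretq_f32_s32(vdupq_n_s32(-Set));        // 0 or -1 in every lane
}

// Usage mirroring the quoted snippet:
//   const float32x4_t BiasMask = MakeLaneMask(KernelFlags, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION);
//   const float32x4_t ReluMask = MakeLaneMask(KernelFlags, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION);
```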
Thank you for your contribution! Excited to see this upstreamed 🚀
The commit 5e0315a should hopefully fix the macOS failures from the previous run.
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
The previous macOS failure was because the BF16 conv asm files were always added in the CMake build, even for targets that don't support them.
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
The commit 8a69461 fixes an error on Linux/Arm64 builds where the compiler targets AArch64 but NCHWc Arm64 BF16 conv is not enabled.
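A hypothetical illustration of the mismatch that fix addresses: compiling for AArch64 does not by itself mean the NCHWc BF16 conv path is compiled in, so the two conditions have to be checked separately (the feature macro name below is a stand-in):

```cpp
#if defined(__aarch64__) && defined(MLAS_NCHWC_CONV_BF16_ENABLED)  // stand-in feature macro
// BF16 NCHWc conv kernels are compiled in; BF16 dispatch is valid here.
#elif defined(__aarch64__)
// AArch64 without the BF16 conv feature: only the FP32 kernels exist,
// so referencing the BF16 entry points would be a build error.
#endif
```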
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Thanks. LGTM overall. Will just seek Copilot's opinion one more time.

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated no new comments.
Description
Add Arm64 BF16 fast-math convolution support in MLAS:
This change adds new AArch64 BF16 asm kernels, wires them into the MLAS platform dispatch, keeps accumulated pointwise batches on the custom BF16 path instead of falling back to generic SBGEMM, and adds the required BF16 build flags.
The new paths are only used when Arm64 BF16 fast-math is enabled via the existing session option. Baseline FP32 behavior is unchanged.
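For reference, a hedged example of enabling that session option from the C++ API. The config key shown is the existing Arm64 BF16 fast-math key used for SBGEMM (verify it against your ONNX Runtime version), and `model.onnx` is a placeholder path:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "bf16-fastmath");
    Ort::SessionOptions so;
    // Assumed key: the session option that enables Arm64 BF16 fast-math.
    so.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");
    // The BF16 conv paths from this PR are then used on Linux/AArch64 only;
    // everywhere else the baseline FP32 kernels run unchanged.
    Ort::Session session(env, "model.onnx", so);
    return 0;
}
```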
Performance
Individual convolution improvements when running on a `c8g` AWS instance, where the base column is FP32 execution, fast-math is fast-math enabled without this PR, and PR is fast-math with this change:

[per-convolution benchmark table not reproduced here]

When running full models, the improvements for MobileNet v2.7 with 8 threads on `c8g` (AWS Graviton 4) and `Standard_D32plds_v6` (Azure Cobalt-100) are:

[c8g and Standard_D32plds_v6 benchmark tables not reproduced here]

(cc: @Rohanjames1997 @snadampal)