From 3e4ec6c03c2a8732d0cae3bab959c37154f1f434 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 10:19:51 -0700 Subject: [PATCH 1/2] Enable BF16 KAI SBGemm on NCHWc ARM builds --- onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp | 7 ++++++- onnxruntime/core/mlas/lib/sbgemm.h | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp index f3a17f30be166..28d0a42f7841e 100644 --- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -45,7 +45,12 @@ MlasConvPointwiseBf16KernelNeon( const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; - mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0); + + // TODO(hasesh/ARM KleidiAI team): Study the impact of using the KleidiAI SBGEMM kernel for this convolution kernel + // and enable it if things look okay. + // Even if re-enabled, honor user override to disable KleidiAI usage if specified. + // That is, mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0); + mlas_backend_kernel_selector_config.use_kleidiai = false; const size_t StrideWidthElements = StrideWidth / sizeof(float); const size_t InputStrideElements = InputStride / sizeof(float); diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index 559c8b48e78b6..35520c7a4226a 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -311,7 +311,7 @@ MlasSBGemmPackBSize( // // Compute the number of bytes required to hold the packed buffer. // -#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) +#if defined(USE_KLEIDIAI) if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr && TransA == CBLAS_TRANSPOSE::CblasNoTrans && @@ -359,7 +359,7 @@ MlasSBGemmConvertPackB( const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig ) { -#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) +#if defined(USE_KLEIDIAI) if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && GetMlasPlatform().MlasSBGemmPackBOverride != nullptr && TransA == CBLAS_TRANSPOSE::CblasNoTrans && @@ -393,7 +393,7 @@ MlasSBGemmBatch( const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig ) { -#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) +#if defined(USE_KLEIDIAI) if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && GetMlasPlatform().MlasSBGemmBatchOverride != nullptr && TransA == CBLAS_TRANSPOSE::CblasNoTrans && From 0539673338a2c1dd321b3b8a13632cab2b283fce Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 17 Mar 2026 10:33:42 -0700 Subject: [PATCH 2/2] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp index 28d0a42f7841e..d513791c695f5 100644 --- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -46,7 +46,7 @@ MlasConvPointwiseBf16KernelNeon( MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; - // TODO(hasesh/ARM KleidiAI team): Study the impact of using the KleidiAI SBGEMM kernel for this convolution kernel + // TODO(hasesh): With the ARM KleidiAI team, study the impact of using the KleidiAI SBGEMM kernel for this convolution kernel // and enable it if things look okay. // Even if re-enabled, honor user override to disable KleidiAI usage if specified. // That is, mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0);