diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp index f3a17f30be166..d513791c695f5 100644 --- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -45,7 +45,12 @@ MlasConvPointwiseBf16KernelNeon( const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; - mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0); + + // TODO(hasesh): With the ARM KleidiAI team, study the impact of using the KleidiAI SBGEMM kernel for this convolution kernel + // and enable it if things look okay. + // Even if re-enabled, honor user override to disable KleidiAI usage if specified. + // That is, mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0); + mlas_backend_kernel_selector_config.use_kleidiai = false; const size_t StrideWidthElements = StrideWidth / sizeof(float); const size_t InputStrideElements = InputStride / sizeof(float); diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index 559c8b48e78b6..35520c7a4226a 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -311,7 +311,7 @@ MlasSBGemmPackBSize( // // Compute the number of bytes required to hold the packed buffer. // -#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) +#if defined(USE_KLEIDIAI) if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr && TransA == CBLAS_TRANSPOSE::CblasNoTrans && @@ -359,7 +359,7 @@ MlasSBGemmConvertPackB( const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig ) { -#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) +#if defined(USE_KLEIDIAI) if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && GetMlasPlatform().MlasSBGemmPackBOverride != nullptr && TransA == CBLAS_TRANSPOSE::CblasNoTrans && @@ -393,7 +393,7 @@ MlasSBGemmBatch( const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig ) { -#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) +#if defined(USE_KLEIDIAI) if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && GetMlasPlatform().MlasSBGemmBatchOverride != nullptr && TransA == CBLAS_TRANSPOSE::CblasNoTrans &&