[MLAS] Integrate KleidiAI BF16 SME2 Kernel Through Mlas SBGEMM Path #26773
Signed-off-by: Patryk Kaiser <[email protected]>
@microsoft-github-policy-service agree company="Arm"
Pull request overview
This PR integrates the Arm® KleidiAI™ SME2 BF16 kernel into the MLAS SBGEMM (single-precision to bfloat16 GEMM) path. The integration improves the performance of bfloat16 matrix multiplication on Arm devices with SME2 support.
Changes:
- Added new `sbgemm_kleidiai.cpp` implementation with KleidiAI BF16 SME2 kernel
- Introduced `BIsPacked` flag to `MLAS_SBGEMM_DATA_PARAMS` to track pre-packed matrix B state
- Added override mechanism in SBGEMM path for KleidiAI kernels on SME2-enabled platforms
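The override mechanism in the last bullet can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `PlatformSketch`, `GetPlatformSketch`, `SBGemmBatchSketch`, `PathTaken`, and `AcceptNonEmpty` are hypothetical names, and the assumption that an override returns `true` when it handled the call (and `false` to request the default path) is inferred from the `!= nullptr &&` conditions quoted in the review comments.

```cpp
#include <cstddef>

// Hypothetical stand-in for the MLAS platform object; only the override slot matters here.
struct PlatformSketch {
    // Assumed contract: returns true when the override handled the request,
    // false when the caller should fall back to the default implementation.
    bool (*SBGemmBatchOverride)(std::size_t M, std::size_t N, std::size_t K) = nullptr;
};

// Process-wide platform instance, created on first use.
inline PlatformSketch& GetPlatformSketch() {
    static PlatformSketch platform;
    return platform;
}

// Reports which path served the call, so the dispatch is observable.
enum class PathTaken { Override, Default };

// Try the registered override first; use the default path if no override is
// registered or if the override declines the request.
inline PathTaken SBGemmBatchSketch(std::size_t M, std::size_t N, std::size_t K) {
    PlatformSketch& platform = GetPlatformSketch();
    if (platform.SBGemmBatchOverride != nullptr &&
        platform.SBGemmBatchOverride(M, N, K)) {
        return PathTaken::Override;
    }
    return PathTaken::Default;
}

// Example override that declines empty problems and accepts everything else.
inline bool AcceptNonEmpty(std::size_t M, std::size_t N, std::size_t K) {
    return M != 0 && N != 0 && K != 0;
}
```

With no override registered, every call takes the default path. Once a platform registers one (in the PR this happens only when SME2 is available), calls are routed to it unless it declines.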
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp | New implementation of SBGEMM using KleidiAI BF16 SME2 kernel |
| onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h | Added function declarations for SBGEMM KleidiAI overrides |
| onnxruntime/core/mlas/lib/kai_ukernel_interface.h | Added SBGEMM ukernel interface declaration |
| onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp | Added SBGEMM ukernel instantiation for SME2 |
| onnxruntime/core/mlas/lib/mlasi.h | Added typedef declarations for SBGEMM override functions |
| onnxruntime/core/mlas/lib/sbgemm.h | Added override mechanism to call KleidiAI SBGEMM functions |
| onnxruntime/core/mlas/lib/platform.cpp | Registered KleidiAI SBGEMM overrides for SME2-enabled platforms |
| onnxruntime/core/mlas/inc/mlas.h | Added BIsPacked field to MLAS_SBGEMM_DATA_PARAMS struct |
| onnxruntime/core/providers/cpu/math/matmul.cc | Set BIsPacked flag when using pre-packed matrix B |
| onnxruntime/test/mlas/unittest/test_sbgemm.h | Updated tests to initialize and set BIsPacked flag |
| cmake/onnxruntime_mlas.cmake | Added sbgemm_kleidiai.cpp to build system |
    MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB)
    {
    #if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
        if (GetMlasPlatform().MlasSBGemmPackBOverride != nullptr  &&
Copilot AI, Jan 16, 2026
Inconsistent spacing: there are two spaces after "nullptr" on line 337, while similar code on line 353 uses a single space. Consider using consistent spacing for code uniformity.
Suggested change:
    - if (GetMlasPlatform().MlasSBGemmPackBOverride != nullptr  &&
    + if (GetMlasPlatform().MlasSBGemmPackBOverride != nullptr &&
    MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool)
    {
    #if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
        if(GetMlasPlatform().MlasSBGemmBatchOverride != nullptr &&
Copilot AI, Jan 16, 2026
Missing space after "if" on line 353. Should be "if (" instead of "if(" for consistency with C++ style conventions.
Suggested change:
    - if(GetMlasPlatform().MlasSBGemmBatchOverride != nullptr &&
    + if (GetMlasPlatform().MlasSBGemmBatchOverride != nullptr &&
    this->MlasConvOverride = ArmKleidiAI::MlasConv;
    #if defined(__aarch64__) && defined(__linux__)
    // Currently only an SME2 variant of SBGEMM exists
    if(ArmKleidiAI::UseSME2){
Copilot AI, Jan 16, 2026
Missing space after "if" on line 612. Should be "if (" instead of "if(" for consistency with C++ style conventions.
Suggested change:
    - if(ArmKleidiAI::UseSME2){
    + if (ArmKleidiAI::UseSME2){
    LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr);

    size_t lhs_resize = 0;
    if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
Copilot AI, Jan 16, 2026
Missing space after "if" on line 220. Should be "if (" instead of "if(" for consistency with C++ style conventions.
Suggested change:
    - if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
    + if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
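For context on the check in this hunk: the total LHS packing buffer is the per-batch stride times the batch size, and that product can wrap around `size_t`. A minimal sketch of such a guard is below, assuming a GCC or Clang toolchain (consistent with the `!defined(_MSC_VER)` guard elsewhere in the PR). `MulOverflowSketch` and `ComputePackedBytes` are hypothetical names; the PR's own `mul_overflow_size_t_builtin` may be implemented differently.

```cpp
#include <cstddef>

// Hypothetical helper: returns true when a * b does not fit in size_t.
// __builtin_mul_overflow is a GCC/Clang builtin; it stores the (possibly
// wrapped) product in *out and returns whether overflow occurred.
inline bool MulOverflowSketch(std::size_t a, std::size_t b, std::size_t* out) {
    return __builtin_mul_overflow(a, b, out);
}

// Computes stride * batch into *total. Returns false, leaving *total
// untouched, when the product would overflow, so the caller can bail out
// instead of allocating an undersized buffer.
inline bool ComputePackedBytes(std::size_t stride, std::size_t batch, std::size_t* total) {
    std::size_t product = 0;
    if (MulOverflowSketch(stride, batch, &product)) {
        return false;
    }
    *total = product;
    return true;
}
```

A caller would resize its packing buffer only when `ComputePackedBytes` returns `true`, and otherwise fall back or report an error.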
    std::vector<float> bias_zero;
    std::vector<std::byte> rhs_packed;
    std::vector<std::byte> lhs_packed;
    std::vector<float> gemv_lhs_row_tmp;
Copilot AI, Jan 16, 2026
The gemv_lhs_row_tmp field is declared but never used in this file. Consider removing it if it's not needed, or add a comment explaining its purpose for future use.
Suggested change:
    - std::vector<float> gemv_lhs_row_tmp;
    if (GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr) {
        size_t bytes_required;
        bytes_required = GetMlasPlatform().MlasSBGemmPackBSizeOverride(N, K);
        if (bytes_required != 0){// If ArmKleidiAI::MlasSBGemmPackBSize ran to completion
Copilot AI, Jan 16, 2026
Inconsistent comment spacing: missing space after "//" on line 310. Should be "// If" instead of "//If" for consistency with the codebase style.
Suggested change:
    - if (bytes_required != 0){// If ArmKleidiAI::MlasSBGemmPackBSize ran to completion
    + if (bytes_required != 0){ // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion
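The hunk above appears to treat a zero return from the pack-size override as "the override did not run to completion". A minimal sketch of that zero-sentinel fallback follows. All names are hypothetical, the default size formula is a placeholder rather than the actual MLAS computation, and falling through to the default path on zero is an assumption based on the quoted comment.

```cpp
#include <cstddef>

// Hypothetical signature for a pack-size override; in the PR the slot is
// MlasSBGemmPackBSizeOverride on the MLAS platform object.
using PackBSizeFn = std::size_t (*)(std::size_t N, std::size_t K);

// Placeholder default: two bytes (one bf16 value) per element of B.
// This is NOT the real MLAS formula, only something to fall back to.
inline std::size_t DefaultPackBSizeSketch(std::size_t N, std::size_t K) {
    return N * K * 2;
}

// Ask the override first; a result of 0 is treated as "declined", in which
// case the default size computation is used instead.
inline std::size_t PackBSizeSketch(PackBSizeFn override_fn, std::size_t N, std::size_t K) {
    if (override_fn != nullptr) {
        const std::size_t bytes_required = override_fn(N, K);
        if (bytes_required != 0) {  // override ran to completion
            return bytes_required;
        }
    }
    return DefaultPackBSizeSketch(N, K);
}

// Example overrides: one that always declines, one that pads N up to a multiple of 16.
inline std::size_t AlwaysDeclines(std::size_t, std::size_t) { return 0; }
inline std::size_t PaddedSize(std::size_t N, std::size_t K) {
    return (N + 15) / 16 * 16 * K * 2;
}
```

One consequence of this design is that an override can never legitimately report a zero-byte buffer, which is acceptable as long as the caller handles empty `N` or `K` before reaching this point.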
Description
This PR integrates the Arm® KleidiAI™ SME2 BF16 kernel through the MLAS SBGEMM path.
Rework of #24346
Motivation and Context
This kernel improves bfloat16 matrix multiplication performance on SME2-enabled devices.