From 39b7e055fd32252969aa7d1028e2779dd483da70 Mon Sep 17 00:00:00 2001
From: Patryk Kaiser
Date: Fri, 15 Aug 2025 14:26:17 +0100
Subject: [PATCH] Integrate SME1 SGEMM KleidiAI kernels

Signed-off-by: Patryk Kaiser
---
 cmake/deps.txt                                |   2 +-
 onnxruntime/core/common/cpuid_info.cc         |   2 +
 onnxruntime/core/common/cpuid_info.h          |   2 +
 .../core/mlas/lib/kleidiai/mlasi_kleidiai.h   |   3 +
 .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 287 +++++++++++-------
 5 files changed, 177 insertions(+), 119 deletions(-)

diff --git a/cmake/deps.txt b/cmake/deps.txt
index b9c9ac3ce8ef3..3d419f7fd913b 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
 cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
 dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
-kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.9.0.tar.gz;a2765979f64efb173a4b8ba4de39dcba9c655786
+kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
 duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc
index 6c66047b4b36a..bbc2a07bef5ea 100644
--- a/onnxruntime/core/common/cpuid_info.cc
+++ b/onnxruntime/core/common/cpuid_info.cc
@@ -192,6 +192,7 @@ void CPUIDInfo::ArmLinuxInit() {
   has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
   has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
   has_arm_sme_ = cpuinfo_has_arm_sme();
+  has_arm_sme2_ = cpuinfo_has_arm_sme2();

   const uint32_t core_cnt = cpuinfo_get_cores_count();
   core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
@@ -332,6 +333,7 @@ void CPUIDInfo::ArmAppleInit() {
   has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
   has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
   has_arm_sme_ = cpuinfo_has_arm_sme();
+  has_arm_sme2_ = cpuinfo_has_arm_sme2();

   // Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect
   // to encounter on Apple platforms.
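With the new has_arm_sme2_ bit plumbed through cpuinfo on both the Linux and Apple init paths, callers can distinguish SME2-capable cores from SME-only ones at runtime. The following is an illustrative sketch, not part of the patch, of how code built on top of CPUIDInfo can branch on the new accessor; the dispatch choices in the comments are placeholders:

    // Illustrative only: runtime dispatch on the capability bits wired up above.
    #include "core/common/cpuid_info.h"

    const auto& cpuid = onnxruntime::CPUIDInfo::GetCPUIDInfo();
    if (cpuid.HasArm_SME2()) {
        // prefer the SME2 micro-kernels
    } else if (cpuid.HasArm_SME()) {
        // fall back to the SME1 (SME-only) micro-kernels this patch introduces
    } else {
        // keep using the existing NEON/SVE paths
    }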
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h
index d49eca7e1d60c..74e1bbf19aed9 100644
--- a/onnxruntime/core/common/cpuid_info.h
+++ b/onnxruntime/core/common/cpuid_info.h
@@ -41,6 +41,7 @@ class CPUIDInfo {
   bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
   bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }
   bool HasArm_SME() const { return has_arm_sme_; }
+  bool HasArm_SME2() const { return has_arm_sme2_; }

   uint32_t GetCurrentCoreIdx() const;

@@ -162,6 +163,7 @@ class CPUIDInfo {
   bool has_arm_sve_i8mm_{false};
   bool has_arm_neon_bf16_{false};
   bool has_arm_sme_{false};
+  bool has_arm_sme2_{false};

   std::string vendor_;
   uint32_t vendor_id_;
diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
index 11fd78c261834..5136061c4769d 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
+++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -15,6 +15,9 @@
 #define RESTRICT __restrict__
 #endif
 namespace ArmKleidiAI {
+// By default we should try for SME2 first before falling back to SME.
+inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
+
 //
 // Buffer packing routines.
 //
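UseSME2 is evaluated once, at namespace scope, from the CPUID query above, so every kernel-selection ternary in sgemm_kleidiai.cpp reduces to a branch on a cached boolean rather than a repeated feature probe. Throughout the rest of the patch, each SME2 ukernel query is paired with its SME (SME1) counterpart via this flag; for example, from the packing path shown below:

    const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                              : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();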
diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
index c579ff1542eb9..ea38f16205a7c 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -8,6 +8,7 @@
 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
 #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
@@ -107,7 +108,8 @@ Routine Description:

 Return Value:

-    None.
+    Returns true if the packing operation was handled by KleidiAI.
+    Returns false if the configuration requires a fallback to the default MLAS implementation.

 --*/
 {
@@ -116,9 +118,12 @@ Return Value:
     }

     if (TransA == CblasNoTrans) {
-        const size_t nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-        const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-        const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
+        const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                                  : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+        const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                                  : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+        const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                                  : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

         // pass zeroed bias values
         const std::vector<float> bias(N);
@@ -152,6 +157,42 @@ ArmKleidiAI::MlasGemmBatch(
     size_t BatchSize,
     MLAS_THREADPOOL* ThreadPool
     )
+/*++
+
+Routine Description:
+
+    This routine performs a batched matrix multiplication (GEMM) operation using KleidiAI kernels.
+    It handles both packed and unpacked inputs and manages tiling and kernel selection depending on
+    SME2 availability. If packing is needed, it prepares the required buffers and invokes the
+    appropriate left-hand side (LHS) and right-hand side (RHS) pack functions.
+
+    The function also applies alpha and beta scaling to the result, supports efficient memcpy
+    paths where possible, and dispatches tile-level GEMM work using multithreading.
+
+Arguments:
+
+    TransA - Supplies the transpose operation for matrix A.
+
+    TransB - Supplies the transpose operation for matrix B.
+
+    M - Supplies the number of rows of matrix A and matrix C.
+
+    N - Supplies the number of columns of matrix B and matrix C.
+
+    K - Supplies the number of columns of matrix A and rows of matrix B.
+
+    Data - Supplies a pointer to the MLAS_SGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters.
+
+    BatchSize - Supplies the number of independent GEMM computations to perform in the batch.
+
+    ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles.
+
+Return Value:
+
+    Returns true if the GEMM operation was handled by KleidiAI.
+    Returns false if the configuration requires a fallback to the default MLAS implementation.
+
+--*/
 {
     if (M == 0 || N == 0) {
         return true;
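To make the contract described above concrete, a hypothetical single-batch call could look like the sketch below. This is illustrative only and not part of the patch; it assumes the usual MLAS_SGEMM_DATA_PARAMS fields (A/lda, B/ldb, C/ldc, alpha, beta, BIsPacked), and A, B, C, M, N, K and ThreadPool are assumed to be supplied by the caller:

    // Illustrative sketch: one row-major SGEMM expressed as a batch of one.
    MLAS_SGEMM_DATA_PARAMS params{};
    params.A = A;      params.lda = K;   // A is M x K
    params.B = B;      params.ldb = N;   // B is K x N
    params.C = C;      params.ldc = N;   // C is M x N
    params.alpha = 1.0f;
    params.beta = 0.0f;
    params.BIsPacked = false;

    if (!ArmKleidiAI::MlasGemmBatch(CblasNoTrans, CblasNoTrans, M, N, K,
                                    &params, /*BatchSize*/ 1, ThreadPool)) {
        // KleidiAI declined this configuration; the caller falls back to the default MLAS path.
    }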
@@ -172,130 +213,134 @@ ArmKleidiAI::MlasGemmBatch(
         return true;
     }

-    if (TransA == CblasNoTrans) {
-        const size_t mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-        const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-        const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-
-        auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-        auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
-
-        if (M < m_step && N < n_step && !Data->BIsPacked) {
-            // Fallback to MLAS
-            return false;
-        }
-
-        std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
-        KaiPackedData.resize(BatchSize);
-
-        size_t LhsPackedStride = 0;
-        std::byte* LhsPackedData = nullptr;
-
-        LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
-        auto LhsPacked = std::make_unique<std::byte[]>(LhsPackedStride * BatchSize);
-        LhsPackedData = LhsPacked.get();
-
-        std::unique_ptr<std::byte[]> RhsPacked{nullptr};
-
-        // It is assumed all B batches require packing or not
-        if (Data[0].BIsPacked) {
-            // We have already decided the matmul variant we are using, before having values for M,N,K
-            MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
-                std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
-
-                kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
-
-                KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
-                KaiPackedData[batch_idx].B = Data[batch_idx].B;
-            });
-        } else {
-            // Multithread pack lhs and rhs
-            size_t RhsPackedStride = 0;
-            std::byte* RhsPackedData = nullptr;
-
-            RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
-            RhsPacked = std::make_unique<std::byte[]>(RhsPackedStride * BatchSize);
-            RhsPackedData = RhsPacked.get();
-
-            MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
-                // lhs odd, rhs even
-                if (batch_idx & 0x1) {
-                    batch_idx >>= 1;
-
-                    std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
+    const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                              : kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                              : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                              : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

-                    kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
+    size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                            : kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
+                            : kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

-                    KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
-                } else {
-                    batch_idx >>= 1;
-
-                    std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);
-
-                    ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr);
-
-                    KaiPackedData[batch_idx].B = reinterpret_cast<const float*>(RhsPackedPtr);
-                }
-            });
-        }
+    if (M < m_step && N < n_step && !Data->BIsPacked) {
+        // Fallback to MLAS
+        return false;
+    }

-        // tile iteration dimensions
-        std::array<size_t, 3> dim;
-        dim[0] = BatchSize;                  // B
-        dim[1] = MlasDivRoundup(M, m_step);  // M
-        dim[2] = MlasDivRoundup(N, n_step);  // N
+    std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
+    KaiPackedData.resize(BatchSize);

-        // Minimize the kernel call count for the number of available threads
-        auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
+    size_t LhsPackedStride = 0;
+    std::byte* LhsPackedData = nullptr;

-        // scale required tiles over available tile processors
-        dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
-        dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);
+    LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
+    auto LhsPacked = std::make_unique<std::byte[]>(LhsPackedStride * BatchSize);
+    LhsPackedData = LhsPacked.get();

-        // compute new step sizes
-        m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step);
-        n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step);
+    std::unique_ptr<std::byte[]> RhsPacked{nullptr};

-        // update tile iterations
-        dim[1] = MlasDivRoundup(M, m_step);
-        dim[2] = MlasDivRoundup(N, n_step);
+    // It is assumed all B batches require packing or not
+    if (Data[0].BIsPacked) {
+        // We have already decided the matmul variant we are using, before having values for M,N,K
+        MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
+            std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
+            kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
+            KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
+            KaiPackedData[batch_idx].B = Data[batch_idx].B;
+        });
+    } else {
+        // Multithread pack lhs and rhs
+        size_t RhsPackedStride = 0;
+        std::byte* RhsPackedData = nullptr;

-        MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
-            // compute B,M,N index from iteration index
-            ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
-            ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
-            ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
+        RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
+        RhsPacked = std::make_unique<std::byte[]>(RhsPackedStride * BatchSize);
+        RhsPackedData = RhsPacked.get();

-            // Get rhs tile, B
-            const size_t rhs_packed_offset =
-                kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K);
+        MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
+            // lhs odd, rhs even
+            if (batch_idx & 0x1) {
+                batch_idx >>= 1;

-            auto BTile = reinterpret_cast<const void*>(
-                reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].B) + rhs_packed_offset
-            );
+                std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);

-            // Get lhs tile, A
-            const size_t lhs_packed_offset =
-                kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K);
+                kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);

-            auto ATile = reinterpret_cast<const void*>(
-                reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].A) + lhs_packed_offset
-            );
+                KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
+            } else {
+                batch_idx >>= 1;

-            auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
-            auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;
+                std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);

-            // Get result tile, C
-            auto CTile = reinterpret_cast<void*>(
-                reinterpret_cast<std::byte*>(Data[BIdx].C) +
-                MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
-                NIdx * n_step * sizeof(float)
-            );
-            // Allocate temporary buffer for raw A*B result
-            std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);
-            float* temp_tile = OutputTile.data();
+                ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr);
+                KaiPackedData[batch_idx].B = reinterpret_cast<const float*>(RhsPackedPtr);
+            }
+        });
+    }
+    // tile iteration dimensions
+    std::array<size_t, 3> dim;
+    dim[0] = BatchSize;                  // B
+    dim[1] = MlasDivRoundup(M, m_step);  // M
+    dim[2] = MlasDivRoundup(N, n_step);  // N
+
+    // Minimize the kernel call count for the number of available threads
+    auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
+
+    // scale required tiles over available tile processors
+    dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
+    dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);
+
+    // compute new step sizes
+    m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step);
+    n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step);
+
+    // update tile iterations
+    dim[1] = MlasDivRoundup(M, m_step);
+    dim[2] = MlasDivRoundup(N, n_step);
+
+    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
+        // compute B,M,N index from iteration index
+        ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
+        ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
+        ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
+
+        // Get rhs tile, B
+        const size_t rhs_packed_offset =
+            UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
+                    : kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);
+
+        auto BTile = reinterpret_cast<const void*>(
+            reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].B) + rhs_packed_offset
+        );
+
+        // Get lhs tile, A
+        const size_t lhs_packed_offset =
+            UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
+                    : kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);
+
+        auto ATile = reinterpret_cast<const void*>(
+            reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].A) + lhs_packed_offset
+        );
+
+        auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
+        auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;
+
+        // Get result tile, C
+        auto CTile = reinterpret_cast<void*>(
+            reinterpret_cast<std::byte*>(Data[BIdx].C) +
+            MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
+            NIdx * n_step * sizeof(float)
+        );
+        // Allocate temporary buffer for raw A*B result
+        std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);
+        float* temp_tile = OutputTile.data();
+
+        if (UseSME2) {
             kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
                 TileSizeM,
                 TileSizeN,
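To make the tile-partitioning arithmetic above concrete, here is a worked example with hypothetical numbers (MlasDivRoundup is ceiling division). Suppose BatchSize = 1, M = N = 1000, m_step = n_step = 64, and 8 threads are available:

    dim           = { 1, ceil(1000/64), ceil(1000/64) } = { 1, 16, 16 }   -> 256 tiles
    RequiredTiles = min(8, 256) = 8
    dim[1]        = ceil(8*16 / (16*16)) = 1
    dim[2]        = ceil(8*16 / (1*16))  = 8    (note: this uses the already-updated dim[1])
    m_step        = 64 * ceil(ceil(1000/1) / 64) = 64 * 16 = 1024
    n_step        = 64 * ceil(ceil(1000/8) / 64) = 64 * 2  = 128
    dim           = { 1, ceil(1000/1024), ceil(1000/128) } = { 1, 1, 8 }

So the scheduler issues 8 kernel calls, one full-height strip of 128 columns per thread (the last strip covering the remaining 104 columns), instead of 256 small tiles.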
@@ -304,9 +349,19 @@ ArmKleidiAI::MlasGemmBatch(
                 TileSizeN * sizeof(float), sizeof(float),
                 -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
             );
+        } else {
+            kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
+                TileSizeM,
+                TileSizeN,
+                K,
+                ATile, BTile, temp_tile,
+                TileSizeN * sizeof(float), sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        }

-            // Final output tile pointer
-            float* dst_tile = reinterpret_cast<float*>(CTile);
+        // Final output tile pointer
+        float* dst_tile = reinterpret_cast<float*>(CTile);

         // quick copy of data in cases where we are not scaling or accumulating anything
         // with bounds checking on tile sizing to ensure the data fits in the memory block
@@ -350,8 +405,4 @@ ArmKleidiAI::MlasGemmBatch(
         return;
     });
     return true;
-    }
-    else {
-        return false;
-    }
 }
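The alpha/beta epilogue that the comments above refer to is unchanged context and therefore not shown in the hunks. As an illustrative sketch only (not the literal code), the per-tile logic it describes amounts to the following; the names (Data, BIdx, dst_tile, temp_tile, TileSizeM/TileSizeN, ldc, alpha, beta) follow the patch, and std::memcpy comes from <cstring>:

    // Illustrative sketch of the per-tile epilogue described above.
    if (Data[BIdx].alpha == 1.0f && Data[BIdx].beta == 0.0f && Data[BIdx].ldc == TileSizeN) {
        // No scaling or accumulation and the tile rows are contiguous in C: copy in one go.
        std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
    } else {
        // General case: C = alpha * (A*B) + beta * C, row by row against the C leading dimension.
        for (size_t i = 0; i < TileSizeM; ++i) {
            for (size_t j = 0; j < TileSizeN; ++j) {
                float acc = Data[BIdx].alpha * temp_tile[i * TileSizeN + j];
                if (Data[BIdx].beta != 0.0f) {
                    acc += Data[BIdx].beta * dst_tile[i * Data[BIdx].ldc + j];
                }
                dst_tile[i * Data[BIdx].ldc + j] = acc;
            }
        }
    }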