diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index c0ab948b41fff..80451377afb19 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -279,6 +279,7 @@ function(setup_kleidiai)
   target_sources(onnxruntime_mlas PRIVATE
     ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
     ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
+    ${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
     ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
     ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
   )
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 248c6d74e6cbd..0d6c1152e09b8 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1955,6 +1955,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
     const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
     bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
     bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
+    bool BIsPacked = false; /**< Whether B is pre-packed */
 };
 
 /**
diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
index fdada83cc6582..a30e655faa1d1 100644
--- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
+++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
@@ -12,6 +12,8 @@
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
 
+#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
     {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
      kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -64,6 +66,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp
      kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
      kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};
 
+const kai_matmul_clamp_f32_bf16p_bf16p_ukernel sbgemm_gemm_sme2 =
+    {kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+     kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa};
+
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
     if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
         return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
@@ -79,3 +94,8 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
         return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
     }
 }
+
+const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& GetKleidiAISBGemmUKernel() {
+    // Currently only SME2 variant exists for bfloat16/SBGEMM kernel
+    return sbgemm_gemm_sme2;
+}
\ No newline at end of file
diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
index 1a6f111d1c794..9594db92b3d0f 100644
--- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
+++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
@@ -8,5 +8,9 @@
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
 
+#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
+
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
+
+const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& GetKleidiAISBGemmUKernel();
\ No newline at end of file
diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
index 216eb35a9b6cc..6bcd4ecb935b7 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
+++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -90,6 +90,36 @@ MlasGemmBatch(
     MLAS_THREADPOOL* ThreadPool
 );
 
+#if defined(__aarch64__) && defined(__linux__)
+size_t
+MLASCALL
+MlasSBGemmPackBSize(
+    size_t N,
+    size_t K
+);
+
+bool
+MLASCALL
+MlasSBGemmPackB(
+    size_t N,
+    size_t K,
+    const float* B,
+    size_t ldb,
+    void* PackedB
+);
+
+bool
+MLASCALL
+MlasSBGemmBatch(
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_SBGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+);
+#endif
+
 size_t
 MLASCALL
 MlasDynamicQgemmPackBSize(
diff --git a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp
new file mode 100644
index 0000000000000..ed1adfa1e7574
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp
@@ -0,0 +1,371 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: MIT
+//
+
+#if defined(__aarch64__) && defined(__linux__)
+
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
+#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h"
+#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+
+#include "mlas.h"
+
+#include "mlasi_kleidiai.h"
+#include "kai_ukernel_interface.h"
+
+// Thread-local reusable buffers to reduce allocation overhead across tiles.
+struct KaiTlsBuffers {
+    std::vector<float> output_tile;
+    std::vector<float> bias_zero;
+    std::vector<std::byte> rhs_packed;
+    std::vector<std::byte> lhs_packed;
+    std::vector<float> gemv_lhs_row_tmp;
+};
+static thread_local KaiTlsBuffers g_kai_tls;
+
+kai_matmul_clamp_f32_bf16p_bf16p_ukernel sbgemm_gemm = GetKleidiAISBGemmUKernel();
+
+/*++
+Routine Description:
+    Apply bias to a 2-D tile (rows x cols).
+
+Arguments:
+    src - Pointer to the temporary A*B results (row-major, rows x cols).
+    rows - Number of rows in the tile.
+    cols - Number of columns in the tile.
+    bias - Pointer to the bias vector or nullptr if no bias.
+    dst - Pointer to the destination tile in C (row-major with leading dimension ldc).
+    ldc - Leading dimension of C (in elements).
+    start_col - Starting column index of the tile (NIdx * n_step).
+
+Notes:
+    Uses a row by row memcpy path when no bias.
+--*/
+static inline void ApplyBias2D(const float* src,
+                               size_t rows,
+                               size_t cols,
+                               const float* bias,
+                               float* dst,
+                               size_t ldc,
+                               size_t start_col) {
+    for (size_t i = 0; i < rows; ++i) {
+        const float* src_row = src + i * cols;
+        float* dst_row = dst + i * ldc;
+
+        if (bias != nullptr) {
+            for (size_t j = 0; j < cols; ++j) {
+                dst_row[j] = src_row[j] + bias[start_col + j];
+            }
+        } else {
+            // No bias but can't memcpy whole so needs to be done row by row.
+            memcpy(dst_row, src_row, cols * sizeof(float));
+        }
+    }
+}
+
+size_t
+MLASCALL
+ArmKleidiAI::MlasSBGemmPackBSize(
+    size_t N,
+    size_t K
+)
+/*++
+
+Routine Description:
+
+    This routine computes the length in bytes for the packed matrix B buffer.
+
+Arguments:
+
+    N - Supplies the number of columns of matrix B.
+
+    K - Supplies the number of rows of matrix B.
+
+Return Value:
+
+    Returns the size in bytes for the packed matrix B buffer.
+
+--*/
+{
+    if (N == 0 || K == 0) {
+        KLEIDIAI_DEBUG_LOG("MlasSBGemmPackBSize returning 0 size. N=" << N << " K=" << K);
+        return 0;
+    }
+    //
+    // Compute the number of bytes required to hold the packed buffer.
+    //
+    size_t bytes = 0;
+    bytes = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(N, K);
+
+    return bytes;
+}
+
+bool
+MLASCALL
+ArmKleidiAI::MlasSBGemmPackB(
+    size_t N,
+    size_t K,
+    const float* B,
+    size_t ldb,
+    void* PackedB
+)
+/*++
+
+Routine Description:
+
+    This routine packs the contents of matrix B to the destination buffer. The
+    destination buffer should be sized based on MlasSBGemmPackBSize(). For best
+    performance, the destination buffer should be aligned to the value returned
+    from MlasGetPreferredBufferAlignment().
+
+Arguments:
+
+    N - Supplies the number of columns of matrix B.
+
+    K - Supplies the number of rows of matrix B.
+
+    B - Supplies the address of matrix B.
+
+    ldb - Supplies the first dimension of matrix B.
+
+    PackedB - Supplies the address of packed matrix B.
+
+Return Value:
+
+    Returns true if the packing operation was handled by KleidiAI.
+    Returns false if the configuration requires a fallback to the default MLAS implementation.
+
+--*/
+{
+    if (N == 0 || K == 0) {
+        KLEIDIAI_DEBUG_LOG("MlasSBGemmPackB one of N or K is 0, falling back to MLAS.");
+        return false;
+    }
+
+    const size_t nr = sbgemm_gemm.get_nr();
+    const size_t kr = sbgemm_gemm.get_kr();
+    const size_t sr = sbgemm_gemm.get_sr();
+
+    // Ensure size and zero the used span.
+    g_kai_tls.bias_zero.resize(N, 0.0f);
+
+    kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
+
+    return true;
+}
+
+bool
+MLASCALL
+ArmKleidiAI::MlasSBGemmBatch(
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_SBGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+)
+/*++
+
+Routine Description:
+
+    This routine performs a bfloat16 batched matrix multiplication (SBGEMM) operation using KleidiAI kernels.
+    If packing is needed, it prepares the required buffers and invokes the
+    appropriate left-hand side (LHS) and right-hand side (RHS) pack functions.
+
+Arguments:
+
+    M - Supplies the number of rows of matrix A and matrix C.
+
+    N - Supplies the number of columns of matrix B and matrix C.
+
+    K - Supplies the number of columns of matrix A and rows of matrix B.
+
+    Data - Supplies a pointer to the MLAS_SBGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters.
+
+    BatchSize - Supplies the number of independent GEMM computations to perform in the batch.
+
+    ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles.
+
+Return Value:
+
+    Returns true if the GEMM operation was handled by KleidiAI.
+    Returns false if the configuration requires a fallback to the default MLAS implementation.
+
+--*/
+{
+    if (M == 0 || N == 0) {
+        return true;
+    }
+
+    size_t m_step = sbgemm_gemm.get_m_step();
+    size_t n_step = sbgemm_gemm.get_n_step();
+
+    if ((M < m_step || N < n_step) && !Data->BIsPacked) {
+        // Fallback
+        return false;
+    }
+
+    const size_t mr = sbgemm_gemm.get_mr();
+    const size_t kr = sbgemm_gemm.get_kr();
+    const size_t sr = sbgemm_gemm.get_sr();
+
+    size_t LhsPackedStride = 0;
+    std::byte* LhsPackedData = nullptr;
+
+    LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr);
+
+    size_t lhs_resize = 0;
+    if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
+    {
+        // size_t wraparound detected for LhsPackedStride, fallback to MLAS
+        return false;
+    }
+
+    g_kai_tls.lhs_packed.resize(lhs_resize);
+    LhsPackedData = g_kai_tls.lhs_packed.data();
+
+    // RHS packed buffer: use TLS reusable vector to minimize allocations
+    size_t RhsPackedStride = 0;
+    std::byte* RhsPackedData = nullptr;
+
+    // It is assumed all B batches require packing or not
+    if (Data[0].BIsPacked) {
+        // We have already decided the matmul variant we are using, before having values for M,N,K
+        MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
+            std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
+            KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_bf16p2vlx2_f32_sme" << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr);
+            kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
+        });
+    } else {
+        // Multithread pack lhs and rhs
+        RhsPackedStride = ArmKleidiAI::MlasSBGemmPackBSize(N, K);
+        size_t rhs_resize = 0;
+        if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize))
+        {
+            // size_t wraparound detected for RhsPackedStride, fallback to MLAS
+            return false;
+        }
+
+        g_kai_tls.rhs_packed.resize(rhs_resize);
+        RhsPackedData = g_kai_tls.rhs_packed.data();
+
+        MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
+            if (batch_idx & 0x1) {
+                batch_idx >>= 1;
+                std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
+                KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_bf16p2vlx2_f32_sme"
+                                    << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr);
+                kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
+            } else {
+                batch_idx >>= 1;
+                std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);
+                ArmKleidiAI::MlasSBGemmPackB(N, K,
+                                             reinterpret_cast<const float*>(Data[batch_idx].B),
+                                             Data[batch_idx].ldb, RhsPackedPtr);
+            }
+        });
+    }
+
+    // tile iteration dimensions
+    std::array<size_t, 3> dim;
+    dim[0] = BatchSize;                  // B
+    dim[1] = MlasDivRoundup(M, m_step);  // M
+    dim[2] = MlasDivRoundup(N, n_step);  // N
+
+    // Minimize the kernel call count for the number of available threads
+    auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
+
+    // scale required tiles over available tile processors
+    dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
+    dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);
+
+    // compute new step sizes
+    m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step);
+    n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step);
+
+    // update tile iterations
+    dim[1] = MlasDivRoundup(M, m_step);
+    dim[2] = MlasDivRoundup(N, n_step);
+
+    // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop.
+    // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively.
+    size_t max_tile_elems = 0;
+    if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) {
+        // size_t wraparound detected for tile size, fallback to MLAS
+        return false;
+    }
+
+    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
+        // compute B,M,N index from iteration index
+        ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
+        ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
+        ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
+
+        // Get rhs tile, B
+        const size_t rhs_packed_offset = sbgemm_gemm.get_rhs_packed_offset(NIdx * n_step, K);
+
+        const std::byte* B_base = Data[0].BIsPacked
+                                      ? reinterpret_cast<const std::byte*>(Data[BIdx].B)
+                                      : (RhsPackedData + RhsPackedStride * BIdx);
+        auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);
+
+        // Get lhs tile, A
+        const size_t lhs_packed_offset = sbgemm_gemm.get_lhs_packed_offset(MIdx * m_step, K);
+
+        const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
+        auto ATile = reinterpret_cast<const void*>(A_base + lhs_packed_offset);
+
+        auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
+        auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;
+
+        // Get result tile, C
+        auto CTile = reinterpret_cast<void*>(
+            reinterpret_cast<std::byte*>(Data[BIdx].C) +
+            MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
+            NIdx * n_step * sizeof(float)
+        );
+        // Allocate temporary buffer for raw A*B result (TLS reusable buffer)
+        size_t tile_elems = TileSizeM * TileSizeN;
+
+        // resize the tile to the required size
+        g_kai_tls.output_tile.resize(tile_elems);
+
+        float* temp_tile = g_kai_tls.output_tile.data();
+        std::fill_n(temp_tile, tile_elems, 0.0f);
+
+        sbgemm_gemm.run_matmul(
+            TileSizeM,
+            TileSizeN,
+            K,
+            ATile, BTile, temp_tile,
+            TileSizeN * sizeof(float), sizeof(float),
+            -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+        );
+
+        // Final output tile pointer
+        float* dst_tile = reinterpret_cast<float*>(CTile);
+        const float* bias = Data[BIdx].Bias;
+
+        // quick copy of data in cases where we are not applying bias
+        // with bounds checking on tile sizing to ensure the data fits in the memory block
+        bool can_memcpy = (
+            bias == nullptr &&
+            Data[BIdx].ldc == TileSizeN &&
+            MIdx * m_step + TileSizeM <= M &&
+            NIdx * n_step + TileSizeN <= N &&
+            TileSizeM != 0 &&
+            TileSizeN != 0);
+
+        if (can_memcpy) {
+            std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
+            return;
+        }
+
+        ApplyBias2D(temp_tile, TileSizeM, TileSizeN, bias, dst_tile, Data[BIdx].ldc, NIdx * n_step);
+        return;
+    });
+    return true;
+}
+#endif
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index ad62cccbfb9c7..6abf3f39e467b 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -901,6 +901,33 @@ typedef bool (MLASCALL MLAS_GEMM_PACK_B_OVERRIDE)(
     size_t ldb,
     void* PackedB);
 
+#if defined(__aarch64__) && defined(__linux__)
+typedef
+bool
+(MLASCALL MLAS_SBGEMM_BATCH_OVERRIDE)(
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_SBGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool);
+
+typedef
+size_t
+(MLASCALL MLAS_SBGEMM_PACK_B_SIZE_OVERRIDE)(
+    size_t N,
+    size_t K);
+
+typedef
+bool
+(MLASCALL MLAS_SBGEMM_PACK_B_OVERRIDE)(
+    size_t N,
+    size_t K,
+    const float* B,
+    size_t ldb,
+    void* PackedB);
+#endif
+
 extern "C" {
 
 #if defined(MLAS_TARGET_AMD64_IX86)
@@ -1332,6 +1359,13 @@ struct MLAS_PLATFORM {
     MLAS_GEMM_PACK_B_OVERRIDE* MlasGemmPackBOverride = nullptr;
     MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr;
     MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr;
+#if defined(__aarch64__) && defined(__linux__)
+    // SBGemm overrides
+    MLAS_SBGEMM_BATCH_OVERRIDE* MlasSBGemmBatchOverride = nullptr;
+    MLAS_SBGEMM_PACK_B_SIZE_OVERRIDE* MlasSBGemmPackBSizeOverride = nullptr;
+    MLAS_SBGEMM_PACK_B_OVERRIDE* MlasSBGemmPackBOverride = nullptr;
+#endif
+
 #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X)
 
     MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 528e71bcffed1..792186780c89f 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -607,6 +607,14 @@ Return Value:
         this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB;
         this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare;
         this->MlasConvOverride = ArmKleidiAI::MlasConv;
+#if defined(__aarch64__) && defined(__linux__)
+        // Currently only an SME2 variant of SBGEMM exists
+        if(ArmKleidiAI::UseSME2){
+            this->MlasSBGemmBatchOverride = ArmKleidiAI::MlasSBGemmBatch;
+            this->MlasSBGemmPackBSizeOverride = ArmKleidiAI::MlasSBGemmPackBSize;
+            this->MlasSBGemmPackBOverride = ArmKleidiAI::MlasSBGemmPackB;
+        }
+#endif
     }
 #endif
 
diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h
index de7fd72fad45a..847db9a2bb4dd 100644
--- a/onnxruntime/core/mlas/lib/sbgemm.h
+++ b/onnxruntime/core/mlas/lib/sbgemm.h
@@ -303,6 +303,16 @@ MlasSBGemmPackBSize(size_t N, size_t K)
     //
     // Compute the number of bytes required to hold the packed buffer.
     //
+#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
+    if (GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr) {
+        size_t bytes_required;
+        bytes_required = GetMlasPlatform().MlasSBGemmPackBSizeOverride(N, K);
+        if (bytes_required != 0){// If ArmKleidiAI::MlasSBGemmPackBSize ran to completion
+            return bytes_required;
+        }
+    }
+#endif
+
     const auto* dispatch = MlasSBGemmGetDispatch();
     if (dispatch == nullptr) return 0;
 
@@ -323,6 +333,13 @@ MlasSBGemmPackBSize(size_t N, size_t K)
 void MLASCALL
 MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB)
 {
+#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
+    if (GetMlasPlatform().MlasSBGemmPackBOverride != nullptr &&
+        GetMlasPlatform().MlasSBGemmPackBOverride(N, K, B, ldb, PackedB)){
+        return;
+    }
+#endif
+
     const auto* dispatch = MlasSBGemmGetDispatch();
     if (dispatch == nullptr) return;
 
@@ -332,6 +349,13 @@ MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* Pac
 void MLASCALL
 MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool)
 {
+#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
+    if(GetMlasPlatform().MlasSBGemmBatchOverride != nullptr &&
+       GetMlasPlatform().MlasSBGemmBatchOverride(M, N, K, Data, BatchN, ThreadPool)){
+        return;
+    }
+#endif
+
     const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch();
     if (dispatch == nullptr) return;
 
diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc
index 530218db31e3d..d76d127406fc7 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
@@ -272,6 +272,7 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
       data[i].ldc = N;
       data[i].Bias = nullptr;
       data[i].OutputProcessor = nullptr;
+      data[i].BIsPacked = bool(packed_b_);
     }
     MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool);
   } else
diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h
index 13701e2e3de46..56487f71b7fe6 100644
--- a/onnxruntime/test/mlas/unittest/test_sbgemm.h
+++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h
@@ -99,10 +99,12 @@ class MlasSBGemmTest : public MlasTestBase {
       params.ldc = ldc;
       params.AIsfp32 = true;
       params.BIsfp32 = true;
+      params.BIsPacked = false;
 
       if (Packed) {
         ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!";
         params.B = PackB(N, K, B, ldb);
+        params.BIsPacked = true;
         params.ldb = 0;
         params.BIsfp32 = false;
       } else {
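
A minimal sketch (not part of the diff) of how a caller is expected to drive the new pre-packed SBGEMM path, mirroring the matmul.cc and test_sbgemm.h changes above. The helper name RunPackedSbgemm and the plain std::vector allocation are illustrative assumptions; real callers should honour MlasGetPreferredBufferAlignment() for the packed buffer and fall back to the unpacked (BIsfp32 = true) path when MlasSBGemmPackBSize() returns 0.

```cpp
// Illustrative sketch only: exercises the public SBGEMM entry points with a
// pre-packed B and sets the new BIsPacked flag. Error handling and aligned
// allocation are omitted for brevity.
#include <cstddef>
#include <vector>

#include "mlas.h"

void RunPackedSbgemm(const float* A, const float* B, float* C,
                     size_t M, size_t N, size_t K,
                     MLAS_THREADPOOL* thread_pool) {
    // Query the packed-B size; 0 means packing is unavailable for this shape
    // and the caller should use the unpacked path instead.
    const size_t packed_b_size = MlasSBGemmPackBSize(N, K);
    std::vector<std::byte> packed_b(packed_b_size);  // prefer an aligned allocation in real code
    MlasSBGemmConvertPackB(N, K, B, /*ldb=*/N, packed_b.data());

    MLAS_SBGEMM_DATA_PARAMS params;
    params.A = A;
    params.lda = K;
    params.B = packed_b.data();
    params.ldb = 0;           // ignored once B is packed
    params.C = C;
    params.ldc = N;
    params.Bias = nullptr;
    params.AIsfp32 = true;    // A is fp32 and converted to bf16 internally
    params.BIsfp32 = false;   // packed B is already bf16
    params.BIsPacked = true;  // new flag introduced by this change

    MlasSBGemmBatch(M, N, K, /*BatchN=*/1, &params, thread_pool);
}
```

With BIsPacked set, MlasSBGemmBatch only packs the LHS per batch and reads the RHS tiles directly from the caller-supplied packed buffer, which is what the new Data[0].BIsPacked branch in sbgemm_kleidiai.cpp implements.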