diff --git a/cmake/deps.txt b/cmake/deps.txt index 71218fd049afb..8b6002f299279 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -58,4 +58,4 @@ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/arch directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1 dawn;https://github.com/google/dawn/archive/4cb1f9be152a4fa6bb695c08cd707ab078a1e2fb.zip;de39336b7715f53c14eec61072293b85cc73b691 -kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244 +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.5.0.tar.gz;d3925c2658b4494d54d8c49dc7b1e1d9b069ec3b diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index de23444e95778..d16c55695772b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -71,7 +71,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, const T* weights_data, size_t weight_matrix_col_size, /*out*/ PrePackedWeights* prepacked_weights) { - size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size); + size_t packb_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size); if (packb_size == 0) { return false; } @@ -87,7 +87,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, memset(packed_weights_data, 0, packed_weights_data_size); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); packed_weights_data += packb_size; weights_data += head_size; } diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 04172e46e9c6a..ea29ec6a23ec0 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -157,6 +157,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -278,6 +279,7 @@ void CPUIDInfo::ArmAppleInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); // Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect // to encounter on Apple platforms. 
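For reference, the call-site pattern used in attention.cc above (query the packed size, pack B once, then run the GEMM against the packed buffer) can be sketched as follows. This is a minimal sketch against the MlasGemmPackBSize, MlasGemmPackB, and MlasGemm signatures shown in this patch; the helper name, leading dimensions, and the single-threaded call are illustrative assumptions, not part of the change.

```cpp
#include <cstdint>
#include <vector>

#include "mlas.h"

// Sketch only: pack B once with the new (TransA, TransB) overloads, then run
// the GEMM with BIsPacked set. Hypothetical helper, not part of the patch.
void RunPrePackedSgemm(const float* A, const float* B, float* C,
                       size_t M, size_t N, size_t K) {
  // Both transpose flags are now passed so the packer can select the
  // KleidiAI SME layout when it applies.
  const size_t packed_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, N, K);
  std::vector<uint8_t> packed_b(packed_size);
  MlasGemmPackB(CblasNoTrans, CblasNoTrans, N, K, B, /*ldb*/ N, packed_b.data());

  MLAS_SGEMM_DATA_PARAMS data;
  data.BIsPacked = true;
  data.A = A;
  data.lda = K;
  data.B = reinterpret_cast<const float*>(packed_b.data());
  data.ldb = 0;  // assumed unused once BIsPacked is set
  data.C = C;
  data.ldc = N;
  data.alpha = 1.0f;
  data.beta = 0.0f;

  MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, data, /*ThreadPool=*/nullptr);
}
```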
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 4c9e7e80db49b..19c53d23e40d4 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -31,6 +31,7 @@ class CPUIDInfo { bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + bool HasArm_SME() const { return has_arm_sme_; } uint32_t GetCurrentCoreIdx() const; @@ -117,6 +118,7 @@ class CPUIDInfo { bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; + bool has_arm_sme_{false}; #if defined(CPUIDINFO_ARCH_X86) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index db21157d2fdce..b99784672c81e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -197,6 +197,119 @@ MlasActivation( size_t ldc ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. 
+ */
+class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
+{
+  public:
+    MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr)
+        : Activation_(Activation), SumBuf_(SumBuf)
+    {
+    }
+
+    void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc)
+        const override;
+
+  private:
+    const MLAS_ACTIVATION& Activation_;
+    const float* SumBuf_;
+};
+
+/**
+ * @brief Data parameters for bfloat16 precision GEMM routine
+ *        All except C are [in] parameters
+ */
+struct MLAS_SBGEMM_DATA_PARAMS {
+    const void* A = nullptr;     /**< address of A */
+    const void* B = nullptr;     /**< address of B */
+    const float* Bias = nullptr; /**< address of Bias, vector size N */
+    float* C = nullptr;          /**< address of result matrix */
+    size_t lda = 0;              /**< leading dimension of A */
+    size_t ldb = 0;              /**< leading dimension of B, 0 when B is pre-packed */
+    size_t ldc = 0;              /**< leading dimension of C */
+    const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
+    bool AIsfp32 = false;        /**< matrix A is fp32, needs to be converted to bf16 */
+    bool BIsfp32 = false;        /**< matrix B is fp32, needs to be converted to bf16 */
+};
+
+/**
+ * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias
+ *        B can be either fp32 or bf16
+ *
+ * Note: We only support uniform batching, so the shapes and types of the
+ * inputs must be the same across all parameter blocks.
+ *
+ * @param[in]    M          row size of matrix A and C
+ * @param[in]    N          column size of matrix B and C
+ * @param[in]    K          column size of matrix A and row size of matrix B
+ * @param[in]    BatchN     number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @param[in]    ThreadPool
+ * @return
+ */
+void MLASCALL
+MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);
+
+/**
+ * @brief For bfloat16 precision GEMM, returns the size of the
+ *        packing buffer needed for the right-hand side
+ * @param[in] N  Number of columns
+ * @param[in] K  Number of rows
+ * @return size of the packing buffer,
+ *         0 if the operation is not supported
+ */
+size_t MLASCALL
+MlasSBGemmPackBSize(size_t N, size_t K);
+
+/**
+ * @brief For bfloat16 precision GEMM, convert the float matrix B
+ *        to bfloat16 precision and pack it into a packing buffer
+ *
+ * @param[in]  N       Number of columns
+ * @param[in]  K       Number of rows
+ * @param[in]  B       Address of matrix B
+ * @param[in]  ldb     leading dimension of input matrix B
+ * @param[out] PackedB Address of the packed matrix
+ */
+void MLASCALL
+MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
+#endif
+
 //
 // Matrix/matrix multiply routines.
// C := alpha * op(A) * op(B) + beta * C @@ -312,17 +425,36 @@ MlasGemm( MLAS_THREADPOOL* ThreadPool ) { - MLAS_SGEMM_DATA_PARAMS Data; - Data.alpha = alpha; - Data.A = A; - Data.lda = lda; - Data.B = B; - Data.ldb = ldb; - Data.beta = beta; - Data.C = C; - Data.ldc = ldc; - - MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool); +#if defined(__aarch64__) && defined(__linux__) + if (TransA == CblasNoTrans && TransB == CblasNoTrans) { + MLAS_SBGEMM_DATA_PARAMS Data; + Data.BIsfp32 = true; + Data.AIsfp32 = true; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.C = C; + Data.ldc = N; + Data.Bias = nullptr; + Data.OutputProcessor = nullptr; + + MlasSBGemmBatch(M, N, K, 1, &Data, ThreadPool); + } else +#endif + { + MLAS_SGEMM_DATA_PARAMS Data; + Data.alpha = alpha; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.beta = beta; + Data.C = C; + Data.ldc = ldc; + + MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool); + } } /** @@ -685,6 +817,8 @@ MlasSymmQgemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ); @@ -692,6 +826,7 @@ MlasGemmPackBSize( void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -1803,119 +1938,6 @@ MlasHalfGemmConvertPackB( void* PackedB ); -#if defined(__aarch64__) && defined(__linux__) -/** - * @brief Whether current CPU supports Bfloat16(bf16) acceleration. - */ -bool MLASCALL -MlasBf16AccelerationSupported(); - -/** - * @brief Interface for bf16 gemm post processors. - * - * Example implementation of this interface includes activations, - * conversion from single precision to precision, etc. - * - * SBGEMM is computed tile by tile. When a tile of result matrix - * is produced, the method Process() is called to process this tile. - * Parameters of this method describe the location and shape of the - * tile. - */ -class MLAS_SBGEMM_POSTPROCESSOR -{ - public: - virtual void Process(float*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} -}; - -/** - * @brief bfloat16 precision activation functions, with optional sum tensor. - * Supplied sum tensor must be the same layout as the GEMM output tensor. - * And the supplied sum tensor will be added to the tensor before activation. 
- */ -class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR -{ - public: - MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) - : Activation_(Activation), SumBuf_(SumBuf) - { - } - - void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) - const override; - - private: - const MLAS_ACTIVATION& Activation_; - const float* SumBuf_; -}; - -/** - * @brief Data parameters for bfloat16 precision GEMM routine - * All except C are [in] parameters - */ -struct MLAS_SBGEMM_DATA_PARAMS { - const void* A = nullptr; /**< address of A */ - const void* B = nullptr; /**< address of B */ - const float* Bias = nullptr; /**< address of Bias, vector size N */ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ - size_t ldc = 0; /**< leading dimension of C*/ - const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; - bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ - bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ -}; - -/** - * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias - * Either B can be either fp32 or bf16 - * - * Note: We only support uniform batching, so shapes and types of the - * input must be same across all parameter blocks. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); - -/** - * @brief For bfloat16 precision GEMM, returns size of the - * packing buffer needed for right hand side - * @param[in] N Number of columns - * @param[in] K Number of rows - * @return size of the packing buffer, - * 0 if operation not supported - */ -size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K); - -/** - * @brief For bfloat16 precision GEMM, convert the float matrix B - * to blfoat16 precision and pack it into a packing buffer - * - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B - * @param[out] PackedB Address of the packed matrix - */ -void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); -#endif - /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index de7fd72fad45a..e4c635df8e797 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -39,6 +39,12 @@ Module Name: #include "mlasi.h" +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" +#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" +#endif + /** * @brief Define the default striding parameters for * the bfloat16 precision gemm operation @@ -223,7 +229,21 @@ 
MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN,
     size_t RangeStartM;
     size_t RangeCountM;
 
-    MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM);
+#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+        const size_t m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa();
+        const size_t BlockedM = (M + m_step - 1) / m_step;
+
+        MlasPartitionWork(ThreadIdM, ThreadCountM, BlockedM, &RangeStartM, &RangeCountM);
+
+        RangeStartM *= m_step;
+        RangeCountM *= m_step;
+        RangeCountM = std::min(RangeCountM, M - RangeStartM);
+    } else
+#endif
+    {
+        MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM);
+    }
 
     //
     // Partition the operation along the N dimension.
@@ -231,34 +251,65 @@ MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN,
     size_t RangeStartN;
     size_t RangeCountN;
 
-    const size_t BlockedN =
-        (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN;
-
+    size_t n_step = MLAS_SGEMM_STRIDEN_THREAD_ALIGN;
+#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+        n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa();
+    }
+#endif
+    const size_t BlockedN = (N + n_step - 1) / n_step;
     MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN);
 
-    RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN;
-    RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN;
-
-    RangeCountN = std::min(N - RangeStartN, RangeCountN);
+    RangeStartN *= n_step;
+    RangeCountN *= n_step;
+    RangeCountN = std::min(RangeCountN, N - RangeStartN);
 
     //
     // Dispatch the partitioned operation.
     //
-    const size_t lda = DataParams->lda;
     const size_t ldc = DataParams->ldc;
-    const float* A = (const float*)DataParams->A + RangeStartM * lda;
     float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
     const float* bias = DataParams->Bias;
 
-    if (!DataParams->BIsfp32) {
-        MlasSBGemmPackedOperation<KernelType>(
-            RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A,
-            lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor
-        );
-    } else {
-        const size_t ldb = DataParams->ldb;
-        const float* B = (const float*)DataParams->B + RangeStartN;
-        MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor);
+#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+        const size_t lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa(
+            RangeStartM, K);
+        const void* A = reinterpret_cast<const void*>(reinterpret_cast<const std::byte*>(DataParams->A) + lhs_offset);
+
+
+        const size_t rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa(
+            RangeStartN, K);
+        const void* B = reinterpret_cast<const void*>(reinterpret_cast<const std::byte*>(DataParams->B) + rhs_offset);
+
+        const size_t dst_stride = ldc * sizeof(float);
+        kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa(RangeCountM, RangeCountN, K,
+            A, B, C, dst_stride, sizeof(float),
+            -std::numeric_limits<float>::max(), std::numeric_limits<float>::max());
+
+        if (bias != nullptr) {
+            for (size_t m = 0; m < RangeCountM; m++) {
+                for (size_t n = 0; n < RangeCountN; n++) {
+                    C[m * ldc + n] += bias[n];
+                }
+            }
+        }
+    } else
+#endif
+    {
+        const size_t lda = DataParams->lda;
+        const float* A = (const float*)DataParams->A + RangeStartM * lda;
+
+        if (!DataParams->BIsfp32) {
+            MlasSBGemmPackedOperation<KernelType>(
+                RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A,
+                lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor
+            );
+        } else {
+            const size_t ldb = DataParams->ldb;
+            const float* B = (const float*)DataParams->B + RangeStartN;
+            MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor);
+        }
     }
 }
 
@@ -306,13 +357,22 @@ MlasSBGemmPackBSize(size_t N, size_t K)
     const auto* dispatch = MlasSBGemmGetDispatch();
     if (dispatch == nullptr) return 0;
 
-    const auto padding = dispatch->BufOverRead;
-    const auto PackedK = dispatch->PackedK;
-    const auto PackedN = dispatch->PackedN;
+    size_t BytesRequired;
+#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+        BytesRequired = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(N, K);
+    } else
+#endif
+    {
+        const auto padding = dispatch->BufOverRead;
+        const auto PackedK = dispatch->PackedK;
+        const auto PackedN = dispatch->PackedN;
+
+        const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1);
+        const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1);
+        BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding;
+    }
 
-    const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1);
-    const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1);
-    const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding;
     const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
     const size_t AlignedBytesRequired =
         (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1);
@@ -335,6 +395,49 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat
     const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch();
     if (dispatch == nullptr) return;
 
+    std::vector<MLAS_SBGEMM_DATA_PARAMS> PackedData;
+#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+        const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa();
+        const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa();
+        const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa();
+
+        const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr);
+        PackedData.resize(BatchN);
+        std::vector<std::byte> LhsPacked(LhsPackedStride * BatchN);
+        std::byte *LhsPackedData = LhsPacked.data();
+        MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
+            const MLAS_SBGEMM_DATA_PARAMS* Params = &(Data[gemm_idx]);
+            std::byte *LhsPackedPtr = &(LhsPackedData[LhsPackedStride * gemm_idx]);
+
+            kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0,
+                Params->A, Params->lda * sizeof(float), LhsPackedPtr);
+
+            MLAS_SBGEMM_DATA_PARAMS* PackedParams = &(PackedData[gemm_idx]);
+            *PackedParams = *Params;
+            PackedParams->A = LhsPackedPtr;
+        });
+
+        size_t RhsPackedStride = 0;
+        std::vector<std::byte> RhsPacked;
+        std::byte *RhsPackedData = nullptr;
+        if (Data[0].ldb != 0) {
+            RhsPackedStride = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(N, K);
+            RhsPacked.resize(RhsPackedStride * BatchN);
+            RhsPackedData = RhsPacked.data();
+
+            MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
+                MLAS_SBGEMM_DATA_PARAMS* PackedParams = &(PackedData[gemm_idx]);
+                std::byte *RhsPackedPtr = &(RhsPackedData[RhsPackedStride * gemm_idx]);
+                MlasSBGemmConvertPackB(N, K, reinterpret_cast<const float*>(PackedParams->B), PackedParams->ldb, RhsPackedPtr);
+
+                
PackedParams->B = RhsPackedPtr; + PackedParams->ldb = 0; + }); + } + } +#endif + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; // @@ -392,7 +495,9 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { ptrdiff_t GemmIdx = tid / ThreadsPerGemm; ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; - operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + + const MLAS_SBGEMM_DATA_PARAMS &Params = (PackedData.empty() ? Data : PackedData.data())[GemmIdx]; + operation(ThreadCountM, ThreadCountN, M, N, K, &Params, ThreadIdx); } ); } diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp index a6a73996c548b..06f5d22ee4291 100644 --- a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -21,6 +21,11 @@ Module Name: #include "mlasi.h" #include "sbgemm.h" +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" +#endif + struct MLAS_SBGEMM_KERNEL_NEON { static constexpr bool PackNeeded = true; static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process @@ -316,21 +321,35 @@ MlasSBGemmConvertPackB( const auto* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return; - const auto PackedN = dispatch->PackedN; +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + const std::vector bias(CountN); - const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa(); - // - // Step through each slice of matrix B along the K dimension. - // - size_t K_block_size; - constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(1, CountN, CountK, nr, kr, sr, + ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr); + } else +#endif + { + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); - for (size_t k = 0; k < CountK; k += K_block_size) { - K_block_size = std::min(CountK - k, Strides.K); + // + // Step through each slice of matrix B along the K dimension. 
+ // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; - MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); - PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } } } diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 616622a8c1f53..14b8afb223128 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -17,6 +17,12 @@ Module Name: #include "mlasi.h" +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#endif + // // Define the number of rows from matrix A to transpose to a local buffer. // @@ -26,6 +32,24 @@ Module Name: #define MLAS_SGEMM_TRANSA_ROWS 12 +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) +bool UseKleidiAISgemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, size_t K + ) +{ + if (TransA == CblasNoTrans && TransB == CblasNoTrans && MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + const size_t n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t BlockedN = (N + n_step - 1) / n_step; + const size_t AlignedN = BlockedN * n_step; + + return AlignedN > 64 && K > 1; + } + return false; +} +#endif + // // Define the parameters to execute segments of a SGEMM operation on worker // threads. @@ -1323,6 +1347,7 @@ Return Value: void MlasSgemmPackedOperation( CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t M, size_t RangeStartN, size_t RangeCountN, @@ -1347,6 +1372,8 @@ Routine Description: TransA - Supplies the transpose operation for matrix A. + TransB - Supplies the transpose operation for matrix B. + M - Supplies the number of rows of matrix A and matrix C. RangeStartN - Supplies the starting column from packed matrix B. @@ -1380,77 +1407,97 @@ Return Value: --*/ { - float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_PACKED_STRIDEK]; - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; +#if !(defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)) + MLAS_UNREFERENCED_PARAMETER(TransB); +#endif - for (size_t n = 0; n < RangeCountN; n += CountN) { +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + if (UseKleidiAISgemm(TransA, TransB, AlignedN, K)) { + const size_t dst_stride = ldc * sizeof(float); - const size_t SliceStartN = RangeStartN + n; + const size_t rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(RangeStartN, K); + const void* rhs_ptr = reinterpret_cast( + reinterpret_cast(PackedB) + rhs_packed_offset); - CountN = std::min(RangeCountN - n, size_t(MLAS_SGEMM_PACKED_STRIDEN)); + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(M, RangeCountN, K, + A, rhs_ptr, C, dst_stride, sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max()); + } else +#endif + { + float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_PACKED_STRIDEK]; // - // Multiply the output matrix by beta as needed. 
+ // Step through each slice of matrix B along the N dimension. // - if (beta != 0.0f && beta != 1.0f) { - MlasSgemmMultiplyBeta(C + n, M, CountN, ldc, beta); - } + size_t CountN; - // - // Step through each slice of matrix B along the K dimension. - // + for (size_t n = 0; n < RangeCountN; n += CountN) { - size_t CountK; - bool ZeroMode = (beta == 0.0f); + const size_t SliceStartN = RangeStartN + n; - for (size_t k = 0; k < K; k += CountK) { + CountN = std::min(RangeCountN - n, size_t(MLAS_SGEMM_PACKED_STRIDEN)); - CountK = std::min(K - k, size_t(MLAS_SGEMM_PACKED_STRIDEK)); + // + // Multiply the output matrix by beta as needed. + // + + if (beta != 0.0f && beta != 1.0f) { + MlasSgemmMultiplyBeta(C + n, M, CountN, ldc, beta); + } // - // Step through each slice of matrix A along the M dimension. + // Step through each slice of matrix B along the K dimension. // - const float* pb = (const float*)PackedB + AlignedN * k + CountK * SliceStartN; - float* c = C + n; + size_t CountK; + bool ZeroMode = (beta == 0.0f); - if (TransA == CblasNoTrans) { + for (size_t k = 0; k < K; k += CountK) { - MlasSgemmKernelLoop(A + k, pb, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); + CountK = std::min(K - k, size_t(MLAS_SGEMM_PACKED_STRIDEK)); - } else { + // + // Step through each slice of matrix A along the M dimension. + // - const float* a = A + k * lda; - size_t RowsRemaining = M; + const float* pb = (const float*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; - while (RowsRemaining > 0) { + if (TransA == CblasNoTrans) { - // - // Transpose elements from matrix A into a local buffer. - // + MlasSgemmKernelLoop(A + k, pb, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); - size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_SGEMM_TRANSA_ROWS)); + } else { - MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); + const float* a = A + k * lda; + size_t RowsRemaining = M; - RowsRemaining -= RowsTransposed; - a += RowsTransposed; + while (RowsRemaining > 0) { - // - // Step through the rows of the local buffer. - // + // + // Transpose elements from matrix A into a local buffer. + // + + size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_SGEMM_TRANSA_ROWS)); - c = MlasSgemmKernelLoop(PanelA, pb, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); + MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); + + RowsRemaining -= RowsTransposed; + a += RowsTransposed; + + // + // Step through the rows of the local buffer. 
+ // + + c = MlasSgemmKernelLoop(PanelA, pb, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); + } } - } - ZeroMode = false; + ZeroMode = false; + } } } } @@ -1497,7 +1544,6 @@ Return Value: --*/ { - const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; @@ -1508,23 +1554,43 @@ Return Value: size_t RangeStartM; size_t RangeCountM; - MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + if (UseKleidiAISgemm(TransA, TransB, N, K) && DataParams->BIsPacked) { + const size_t m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t BlockedM = (M + m_step - 1) / m_step; + + MlasPartitionWork(ThreadIdM, ThreadCountM, BlockedM, &RangeStartM, &RangeCountM); + + RangeStartM *= m_step; + RangeCountM *= m_step; + RangeCountM = std::min(RangeCountM, M - RangeStartM); + } else +#endif + { + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + } // // Partition the operation along the N dimension. // + size_t n_step = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + if (UseKleidiAISgemm(TransA, TransB, N, K) && DataParams->BIsPacked) { + n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + } +#endif size_t RangeStartN; size_t RangeCountN; - const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + const size_t BlockedN = (N + n_step - 1) / + n_step; MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); - RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeStartN *= n_step; + RangeCountN *= n_step; RangeCountN = std::min(N - RangeStartN, RangeCountN); @@ -1539,10 +1605,17 @@ Return Value: float* C = DataParams->C + RangeStartM * ldc + RangeStartN; if (DataParams->BIsPacked) { - - MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN, +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + if (UseKleidiAISgemm(TransA, TransB, N, K)) { + const size_t lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(RangeStartM, K); + A = reinterpret_cast( + reinterpret_cast(DataParams->A) + lhs_packed_offset); + } +#endif + MlasSgemmPackedOperation(TransA, TransB, RangeCountM, RangeStartN, RangeCountN, K, DataParams->alpha, A, lda, DataParams->B, - BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, DataParams->beta, C, ldc); + BlockedN * n_step, DataParams->beta, C, ldc); } else { @@ -1572,6 +1645,59 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { + if (M == 0 || N == 0 || K == 0) { + return; + } + + std::vector PackedData; +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + const bool use_kleidiai = UseKleidiAISgemm(TransA, TransB, N, K) && (M > 1 || Data[0].BIsPacked); + size_t LhsPackedStride = 0; + std::vector LhsPacked; + std::byte *LhsPackedData = nullptr; + + size_t RhsPackedStride = 0; + std::vector RhsPacked; + std::byte *RhsPackedData = nullptr; + if (use_kleidiai) { + const size_t mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + PackedData.resize(BatchSize); + LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr); + 
LhsPacked.resize(LhsPackedStride * BatchSize);
+        LhsPackedData = LhsPacked.data();
+        MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t gemm_idx) {
+            const MLAS_SGEMM_DATA_PARAMS* Params = &(Data[gemm_idx]);
+            std::byte *LhsPackedPtr = &(LhsPackedData[LhsPackedStride * gemm_idx]);
+
+            kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0,
+                Params->A, Params->lda * sizeof(float), LhsPackedPtr);
+
+            MLAS_SGEMM_DATA_PARAMS* PackedParams = &(PackedData[gemm_idx]);
+            *PackedParams = *Params;
+            PackedParams->A = reinterpret_cast<const float*>(LhsPackedPtr);
+        });
+
+        if (!Data[0].BIsPacked) {
+            RhsPackedStride = MlasGemmPackBSize(TransA, TransB, N, K);
+            RhsPacked.resize(RhsPackedStride * BatchSize);
+            RhsPackedData = RhsPacked.data();
+
+            MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t gemm_idx) {
+                MLAS_SGEMM_DATA_PARAMS* PackedParams = &(PackedData[gemm_idx]);
+                std::byte *RhsPackedPtr = &(RhsPackedData[RhsPackedStride * gemm_idx]);
+                MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(PackedParams->B),
+                    PackedParams->ldb, RhsPackedPtr);
+
+                PackedParams->B = reinterpret_cast<const float*>(RhsPackedPtr);
+                PackedParams->ldb = 0;
+                PackedParams->BIsPacked = true;
+            });
+        }
+    }
+#endif
 
     //
     // Compute the number of target threads given the complexity of the SGEMM
@@ -1626,8 +1752,9 @@ MlasGemmBatch(
         {
             ptrdiff_t GemmIdx = tid / ThreadsPerGemm;
             ptrdiff_t ThreadIdx = tid % ThreadsPerGemm;
+            const MLAS_SGEMM_DATA_PARAMS *Params = PackedData.empty()? &(Data[GemmIdx]) : &(PackedData[GemmIdx]);
             MlasSgemmThreaded(ThreadCountM, ThreadCountN,
-                TransA, TransB, M, N, K, &(Data[GemmIdx]), ThreadIdx);
+                TransA, TransB, M, N, K, Params, ThreadIdx);
         });
 }
 #if defined(_MSC_VER) && !defined(__clang__)
@@ -1637,6 +1764,8 @@ MlasGemmBatch(
 size_t
 MLASCALL
 MlasGemmPackBSize(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
     size_t N,
     size_t K
     )
@@ -1648,6 +1777,10 @@ Routine Description:
 
 Arguments:
 
+    TransA - Supplies the transpose operation on A matrix
+
+    TransB - Supplies the transpose operation on B matrix
+
     N - Supplies the number of columns of matrix B.
 
     K - Supplies the number of rows of matrix B.
@@ -1658,14 +1791,26 @@ Return Value:
 
 --*/
 {
+#if !(defined(USE_KLEIDIAI) && !defined(_MSVC_LANG))
+    MLAS_UNREFERENCED_PARAMETER(TransA);
+    MLAS_UNREFERENCED_PARAMETER(TransB);
+#endif
     //
     // Compute the number of bytes required to hold the packed buffer.
    //
 
-    const size_t AlignedN =
-        (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1);
+    size_t BytesRequired;
+#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)
+    if (UseKleidiAISgemm(TransA, TransB, N, K)) {
+        BytesRequired = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K);
+    } else
+#endif
+    {
+        const size_t AlignedN =
+            (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1);
+        BytesRequired = AlignedN * K * sizeof(float);
+    }
 
-    const size_t BytesRequired = AlignedN * K * sizeof(float);
     const size_t BufferAlignment = MlasGetPreferredBufferAlignment();
     const size_t AlignedBytesRequired =
         (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1);
@@ -1676,6 +1821,7 @@ Return Value:
 void
 MLASCALL
 MlasGemmPackB(
+    CBLAS_TRANSPOSE TransA,
     CBLAS_TRANSPOSE TransB,
     size_t N,
     size_t K,
@@ -1694,6 +1840,8 @@ Routine Description:
 
 Arguments:
 
+    TransA - Supplies the transpose operation for matrix A.
+
     TransB - Supplies the transpose operation for matrix B.
 
     N - Supplies the number of columns of matrix B.
@@ -1712,25 +1860,42 @@ Return Value: --*/ { - const size_t AlignedN = - (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); +#if !(defined(USE_KLEIDIAI) && !defined(_MSVC_LANG)) + MLAS_UNREFERENCED_PARAMETER(TransA); +#endif - // - // Step through each slice of matrix B along the K dimension. - // +#if defined(USE_KLEIDIAI) && !defined(_MSVC_LANG) + if (UseKleidiAISgemm(TransA, TransB, N, K)) { + const std::vector bias(N); + + const size_t nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, + ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr); + } else +#endif + { + const size_t AlignedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); - size_t CountK; + // + // Step through each slice of matrix B along the K dimension. + // - for (size_t k = 0; k < K; k += CountK) { + size_t CountK; - CountK = std::min(K - k, size_t(MLAS_SGEMM_PACKED_STRIDEK)); + for (size_t k = 0; k < K; k += CountK) { - if (TransB == CblasNoTrans) { - MlasSgemmCopyPackB((float*)PackedB, B + k * ldb, ldb, N, CountK); - } else { - MlasSgemmTransposePackB((float*)PackedB, B + k, ldb, N, CountK); - } + CountK = std::min(K - k, size_t(MLAS_SGEMM_PACKED_STRIDEK)); - PackedB = (float*)PackedB + AlignedN * CountK; + if (TransB == CblasNoTrans) { + MlasSgemmCopyPackB((float*)PackedB, B + k * ldb, ldb, N, CountK); + } else { + MlasSgemmTransposePackB((float*)PackedB, B + k, ldb, N, CountK); + } + + PackedB = (float*)PackedB + AlignedN * CountK; + } } } diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5406dd1a40446..4e2129ca726e5 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -102,6 +102,7 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( bool GemmPackBFp32(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, @@ -116,7 +117,7 @@ bool GemmPackBFp32(AllocatorPtr& alloc, const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); - packed_b_size = MlasGemmPackBSize(N, K); + packed_b_size = MlasGemmPackBSize(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, N, K); if (packed_b_size == 0) { return false; } @@ -129,7 +130,8 @@ bool GemmPackBFp32(AllocatorPtr& alloc, // if and when we try to cache this pre-packed buffer for sharing between sessions. memset(packed_b_data, 0, packed_b_size); - MlasGemmPackB(trans_b ? CblasTrans : CblasNoTrans, + MlasGemmPackB(trans_a ? CblasTrans : CblasNoTrans, + trans_b ? 
CblasTrans : CblasNoTrans, N, K, tensor_b.Data(), @@ -263,7 +265,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_B_ != CblasNoTrans, packed_b_, packed_b_size, b_shape_); + is_packed = GemmPackBFp32(alloc, tensor, trans_A_ != CblasNoTrans, trans_B_ != CblasNoTrans, packed_b_, packed_b_size, b_shape_); bool share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); diff --git a/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h b/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h index 599847e61a54f..026df3838b943 100644 --- a/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h +++ b/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h @@ -9,6 +9,7 @@ namespace onnxruntime { bool GemmPackBFp32(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de908..8a91bd01116de 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -195,7 +195,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc } else #endif { - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + is_packed = GemmPackBFp32(alloc, tensor, false, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); } bool share_prepacked_weights = (prepacked_weights != nullptr); diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index c0171f7728ea8..d781de2eb5541 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -194,7 +194,7 @@ bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& allo const size_t N = static_cast(shape[1]); const size_t K = static_cast(shape[2]); - const size_t packed_weights_size = MlasGemmPackBSize(N, K); + const size_t packed_weights_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, N, K); if (packed_weights_size == 0) { return false; } @@ -215,7 +215,7 @@ bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& allo const size_t N_x_K = N * K; const auto* weights_data = weights.Data(); for (int64_t dir = 0; dir < num_directions; ++dir) { - MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasTrans, N, K, weights_data, K, packed_weights_data); weights_data += N_x_K; packed_weights_data += packed_weights_size; } @@ -244,12 +244,12 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& const auto hidden_size_x_2 = N - hidden_size_; // We are making two packed buffers, one for ZR weights and another for H weights. 
- const size_t ZR_packed_size = MlasGemmPackBSize(narrow(hidden_size_x_2), narrow(K)); + const size_t ZR_packed_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K)); if (ZR_packed_size == 0) { return false; } - const size_t H_packed_size = MlasGemmPackBSize(narrow(hidden_size_), narrow(K)); + const size_t H_packed_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K)); if (H_packed_size == 0) { return false; } @@ -275,18 +275,18 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& const auto hidden_2_step = hidden_size_x_2 * K; const auto hidden_1_step = hidden_size_ * K; // square const auto* weights_data = weights.Data(); - MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); weights_data += hidden_2_step; - MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); if (num_directions == 2) { weights_data += hidden_1_step; buffer_ZR = static_cast(buffer_ZR) + ZR_packed_size; - MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); weights_data += hidden_2_step; buffer_H = static_cast(buffer_H) + H_packed_size; - MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); } return true; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index e95ad707cf2b0..b38e271fdbe4a 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -196,7 +196,7 @@ Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packe return Status::OK(); } - const size_t packed_weights_size = MlasGemmPackBSize(N, K); + const size_t packed_weights_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, N, K); if (packed_weights_size == 0) { return Status::OK(); } @@ -217,7 +217,7 @@ Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packe const auto* weights_data = weights.Data(); for (int i = 0; i < num_directions_; i++) { - MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasTrans, N, K, weights_data, K, packed_weights_data); packed_weights_data = static_cast(packed_weights_data) + packed_weights_size; weights_data += N * K; } diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index a94d33cd77f63..0eb552083e8b9 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -30,9 +30,13 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); if (pack_b) { - size_t pack_b_size = MlasGemmPackBSize(N, K); + size_t pack_b_size = MlasGemmPackBSize( + trans_a ? 
CblasTrans : CblasNoTrans,
+        CblasNoTrans, N, K);
     std::vector<float> B_packed(pack_b_size);
-    MlasGemmPackB(CblasNoTrans, N, K, B.data(), N, B_packed.data());
+    MlasGemmPackB(
+        trans_a ? CblasTrans : CblasNoTrans,
+        CblasNoTrans, N, K, B.data(), N, B_packed.data());
 
     MlasGemm(
         trans_a ? CblasTrans : CblasNoTrans,
diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.h b/onnxruntime/test/mlas/unittest/test_fgemm.h
index 2bd094152d6f0..e7741fba1c3fb 100644
--- a/onnxruntime/test/mlas/unittest/test_fgemm.h
+++ b/onnxruntime/test/mlas/unittest/test_fgemm.h
@@ -112,11 +112,11 @@ class FgemmPackedContext {
       float* C,
       size_t ldc,
       MLAS_THREADPOOL* threadpool) {
-    size_t PackedBSize = MlasGemmPackBSize(N, K);
+    size_t PackedBSize = MlasGemmPackBSize(TransA, TransB, N, K);
     void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true);
     std::vector<MLAS_SGEMM_DATA_PARAMS> data(BatchSize);
     for (size_t i = 0; i < BatchSize; i++) {
-      MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i);
+      MlasGemmPackB(TransA, TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i);
       data[i].BIsPacked = true;
       data[i].A = A + M * K * i;
       data[i].lda = lda;