Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/arch
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
dawn;https://github.com/google/dawn/archive/4cb1f9be152a4fa6bb695c08cd707ab078a1e2fb.zip;de39336b7715f53c14eec61072293b85cc73b691
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.5.0.tar.gz;d3925c2658b4494d54d8c49dc7b1e1d9b069ec3b
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
const T* weights_data,
size_t weight_matrix_col_size,
/*out*/ PrePackedWeights* prepacked_weights) {
size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size);
size_t packb_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size);
if (packb_size == 0) {
return false;
}
Expand All @@ -87,7 +87,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
memset(packed_weights_data, 0, packed_weights_data_size);

for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
MlasGemmPackB(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
packed_weights_data += packb_size;
weights_data += head_size;
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
has_arm_sme_ = cpuinfo_has_arm_sme();

const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
Expand Down Expand Up @@ -278,6 +279,7 @@ void CPUIDInfo::ArmAppleInit() {
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
has_arm_sme_ = cpuinfo_has_arm_sme();

// Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect
// to encounter on Apple platforms.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class CPUIDInfo {
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }
bool HasArm_SME() const { return has_arm_sme_; }

uint32_t GetCurrentCoreIdx() const;

Expand Down Expand Up @@ -117,6 +118,7 @@ class CPUIDInfo {
bool has_arm_neon_i8mm_{false};
bool has_arm_sve_i8mm_{false};
bool has_arm_neon_bf16_{false};
bool has_arm_sme_{false};

#if defined(CPUIDINFO_ARCH_X86)

Expand Down
270 changes: 146 additions & 124 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,119 @@ MlasActivation(
size_t ldc
);

#if defined(__aarch64__) && defined(__linux__)
/**
* @brief Whether current CPU supports Bfloat16(bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

/**
 * @brief Abstract interface for bf16 gemm (SBGEMM) post processors.
 *
 * Example implementations of this interface include activations,
 * conversion from single precision to another precision, etc.
 *
 * SBGEMM computes the result matrix one tile at a time. Every time a
 * tile has been produced, Process() is invoked on it; the arguments
 * describe where the tile sits in the matrix and what shape it has.
 */
class MLAS_SBGEMM_POSTPROCESSOR
{
public:
    /**
     * @brief Post-process one tile of the result matrix.
     *
     * @param C       the address of matrix to process
     * @param StartM  the start row index of matrix
     * @param StartN  the start col index of matrix
     * @param CountM  the element count per row to process
     * @param CountN  the element count per col to process
     * @param ldc     the leading dimension of matrix
     */
    virtual void Process(
        float* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
    ) const = 0;

    // Virtual so that implementations can be destroyed through a base pointer.
    virtual ~MLAS_SBGEMM_POSTPROCESSOR() = default;
};

/**
 * @brief Post processor providing bfloat16 precision activation functions,
 *        with an optional sum tensor.
 *
 * When a sum tensor is supplied it must have the same layout as the GEMM
 * output tensor, and it is added to the tensor before activation.
 *
 * The activation is stored by reference and the sum buffer by raw pointer;
 * neither is copied, so both must stay valid for as long as this processor
 * is in use.
 */
class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
{
public:
    /**
     * @param Activation  activation description to use (kept by reference)
     * @param SumBuf      optional sum tensor; nullptr when not supplied
     */
    MLAS_SBGEMM_ACTIVATION_PROCESSOR(
        const MLAS_ACTIVATION& Activation,
        const float* SumBuf = nullptr
    )
        : Activation_(Activation),
          SumBuf_(SumBuf)
    {
    }

    // Declared here only; the definition lives outside this header section.
    void Process(
        float* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
    ) const override;

private:
    const MLAS_ACTIVATION& Activation_;  // not owned
    const float* SumBuf_;                // not owned, may be nullptr
};

/**
 * @brief Parameter block describing one bfloat16 precision GEMM problem.
 *
 * Every member except C is an [in] parameter. All members carry a default
 * (null pointer, zero or false), so a default-constructed block is fully
 * initialized.
 */
struct MLAS_SBGEMM_DATA_PARAMS {
    const void* A{nullptr};     /**< address of A */
    const void* B{nullptr};     /**< address of B */
    const float* Bias{nullptr}; /**< address of Bias, vector size N */
    float* C{nullptr};          /**< address of result matrix */
    size_t lda{0};              /**< leading dimension of A */
    size_t ldb{0};              /**< leading dimension of B, 0 when B is pre-packed */
    size_t ldc{0};              /**< leading dimension of C */
    const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor{nullptr};
    bool AIsfp32{false};        /**< matrix A is fp32, needs to be converted to bf16 */
    bool BIsfp32{false};        /**< matrix B is fp32, needs to be converted to bf16 */
};

/**
* @brief Bfloat16 precision Batched GEMM: C = A * B + Bias
* A and B can each be either fp32 or bf16
*
* Note: We only support uniform batching, so shapes and types of the
* input must be same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);

/**
* @brief For bfloat16 precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t MLASCALL
MlasSBGemmPackBSize(size_t N, size_t K);

/**
* @brief For bfloat16 precision GEMM, convert the float matrix B
* to bfloat16 precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void MLASCALL
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif

//
// Matrix/matrix multiply routines.
// C := alpha * op(A) * op(B) + beta * C
Expand Down Expand Up @@ -312,17 +425,36 @@ MlasGemm(
MLAS_THREADPOOL* ThreadPool
)
{
MLAS_SGEMM_DATA_PARAMS Data;
Data.alpha = alpha;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.beta = beta;
Data.C = C;
Data.ldc = ldc;

MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool);
#if defined(__aarch64__) && defined(__linux__)
if (TransA == CblasNoTrans && TransB == CblasNoTrans) {
MLAS_SBGEMM_DATA_PARAMS Data;
Data.BIsfp32 = true;
Data.AIsfp32 = true;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.C = C;
Data.ldc = N;
Data.Bias = nullptr;
Data.OutputProcessor = nullptr;

MlasSBGemmBatch(M, N, K, 1, &Data, ThreadPool);
} else
#endif
{
MLAS_SGEMM_DATA_PARAMS Data;
Data.alpha = alpha;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.beta = beta;
Data.C = C;
Data.ldc = ldc;

MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool);
}
}

/**
Expand Down Expand Up @@ -685,13 +817,16 @@ MlasSymmQgemmBatch(
size_t
MLASCALL
MlasGemmPackBSize(
CBLAS_TRANSPOSE TransA,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a thought - do we need to know whether A is going to be transposed in order to determine the size of packed B? It seems like it bloats the API unnecessarily. The same thought applies to the MlasGemmPackB(...) routine. It seems like its only real use is to check whether the Kleidi library can be used. Can this be determined separately and passed to the packing routines? The issue with introducing this into the API is that other non-Kleidi paths must then handle these parameters appropriately once they are exposed.

CBLAS_TRANSPOSE TransB,
size_t N,
size_t K
);

void
MLASCALL
MlasGemmPackB(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t N,
size_t K,
Expand Down Expand Up @@ -1803,119 +1938,6 @@ MlasHalfGemmConvertPackB(
void* PackedB
);

#if defined(__aarch64__) && defined(__linux__)
/**
* @brief Whether current CPU supports Bfloat16(bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

/**
 * @brief Abstract interface for bf16 gemm (SBGEMM) post processors.
 *
 * Example implementations of this interface include activations,
 * conversion from single precision to another precision, etc.
 *
 * SBGEMM computes the result matrix one tile at a time. Every time a
 * tile has been produced, Process() is invoked on it; the arguments
 * describe where the tile sits in the matrix and what shape it has.
 */
class MLAS_SBGEMM_POSTPROCESSOR
{
public:
    /**
     * @brief Post-process one tile of the result matrix.
     *
     * @param C       the address of matrix to process
     * @param StartM  the start row index of matrix
     * @param StartN  the start col index of matrix
     * @param CountM  the element count per row to process
     * @param CountN  the element count per col to process
     * @param ldc     the leading dimension of matrix
     */
    virtual void Process(
        float* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
    ) const = 0;

    // Virtual so that implementations can be destroyed through a base pointer.
    virtual ~MLAS_SBGEMM_POSTPROCESSOR() = default;
};

/**
 * @brief Post processor providing bfloat16 precision activation functions,
 *        with an optional sum tensor.
 *
 * When a sum tensor is supplied it must have the same layout as the GEMM
 * output tensor, and it is added to the tensor before activation.
 *
 * The activation is stored by reference and the sum buffer by raw pointer;
 * neither is copied, so both must stay valid for as long as this processor
 * is in use.
 */
class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
{
public:
    /**
     * @param Activation  activation description to use (kept by reference)
     * @param SumBuf      optional sum tensor; nullptr when not supplied
     */
    MLAS_SBGEMM_ACTIVATION_PROCESSOR(
        const MLAS_ACTIVATION& Activation,
        const float* SumBuf = nullptr
    )
        : Activation_(Activation),
          SumBuf_(SumBuf)
    {
    }

    // Declared here only; the definition lives outside this header section.
    void Process(
        float* C,
        size_t StartM,
        size_t StartN,
        size_t CountM,
        size_t CountN,
        size_t ldc
    ) const override;

private:
    const MLAS_ACTIVATION& Activation_;  // not owned
    const float* SumBuf_;                // not owned, may be nullptr
};

/**
 * @brief Parameter block describing one bfloat16 precision GEMM problem.
 *
 * Every member except C is an [in] parameter. All members carry a default
 * (null pointer, zero or false), so a default-constructed block is fully
 * initialized.
 */
struct MLAS_SBGEMM_DATA_PARAMS {
    const void* A{nullptr};     /**< address of A */
    const void* B{nullptr};     /**< address of B */
    const float* Bias{nullptr}; /**< address of Bias, vector size N */
    float* C{nullptr};          /**< address of result matrix */
    size_t lda{0};              /**< leading dimension of A */
    size_t ldb{0};              /**< leading dimension of B, 0 when B is pre-packed */
    size_t ldc{0};              /**< leading dimension of C */
    const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor{nullptr};
    bool AIsfp32{false};        /**< matrix A is fp32, needs to be converted to bf16 */
    bool BIsfp32{false};        /**< matrix B is fp32, needs to be converted to bf16 */
};

/**
* @brief Bfloat16 precision Batched GEMM: C = A * B + Bias
* A and B can each be either fp32 or bf16
*
* Note: We only support uniform batching, so shapes and types of the
* input must be same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);

/**
* @brief For bfloat16 precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t MLASCALL
MlasSBGemmPackBSize(size_t N, size_t K);

/**
* @brief For bfloat16 precision GEMM, convert the float matrix B
* to bfloat16 precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void MLASCALL
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif

/**
* @brief Indirect Depthwise convolution for fp16
* @param Input Supplies the indirect buffer for NHWC input
Expand Down
Loading