Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,10 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512;

struct MLAS_QNBIT_GEMM_DISPATCH;

extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;
const MLAS_QNBIT_GEMM_DISPATCH&
GetMlasQNBitGemmDispatchNeon(
bool InitializeWithDotSupport
);

extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

Expand Down
17 changes: 6 additions & 11 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
this->RopeDispatch = &MlasRopeDispatchNeon;
this->HGemmDispatch = &MlasHGemmDispatchNeon;
this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon;
Expand All @@ -552,22 +551,16 @@ Return Value:
// Check if the processor supports ASIMD dot product instructions.
//

bool HasDotProductInstructions;

#if defined(_WIN32)
HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
#else
// Use the cpuinfo value which is read from sysctl and has some additional special cases.
// https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379
// Note:
// Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips
// as well as failing on other ARM chips as it is an EL1 level register that requires extra
// privileges to read.
//
// uint64_t isar0_el1;
// asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
// HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();
#endif
// const bool HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;

const bool HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();

if (HasDotProductInstructions) {
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUdot;
Expand All @@ -578,6 +571,8 @@ Return Value:
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
}

this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions);

#if defined(__linux__)
//
// Check if the processor supports ASIMD I8MM instructions.
Expand Down
50 changes: 31 additions & 19 deletions onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ Module Name:

--*/

#include "qnbitgemm_kernel_neon.h"

#include <arm_neon.h>

#include <cassert>

#include "qnbitgemm.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
Expand Down Expand Up @@ -172,30 +173,41 @@ Q4BitGemmPerGemmWorkspaceAlignment(
} // namespace sqnbitgemm_neon

//
// Kernel dispatch structure definition.
// Kernel dispatch structure accessor.
//

const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
MLAS_QNBIT_GEMM_DISPATCH d;
const MLAS_QNBIT_GEMM_DISPATCH&
GetMlasQNBitGemmDispatchNeon(
bool InitializeWithDotSupport
)
{
// Note: The InitializeWithX parameters are only used in the invocation of this method that initializes the static
// MLAS_QNBIT_GEMM_DISPATCH instance.

static const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeon = [&]() {
MLAS_QNBIT_GEMM_DISPATCH d;

d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData;
d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData;

d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize;
d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment;
d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize;
d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32;
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) {
d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
}
d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8;
d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32;
if (InitializeWithDotSupport) {
d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
}
d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8;

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16;
d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16;
d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16;
d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16;
d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64

return d;
}();
return d;
}();

return MlasQNBitGemmDispatchNeon;
}
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Module Name:
#include <cstddef>
#include <utility>

#include "mlas_qnbit.h"
#include "mlasi.h"

namespace sqnbitgemm_neon
Expand Down
Loading