diff --git a/source/backend/cpu/CPURuntime.cpp b/source/backend/cpu/CPURuntime.cpp index c2480a2c1..77d9c4934 100644 --- a/source/backend/cpu/CPURuntime.cpp +++ b/source/backend/cpu/CPURuntime.cpp @@ -30,7 +30,9 @@ // ref: https://cs.android.com/android/platform/superproject/+/master:bionic/libc/kernel/uapi/asm-arm64/asm/hwcap.h;drc=04da58f5b3bc40dbbafb4f8422aa2991479d9e1e;l=70 #define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000) #define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000) -#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002) + +#define CPUINFO_ARM_LINUX_FEATURE2_SVE2 UINT32_C(0x00000002) +#define CPUINFO_ARM_LINUX_FEATURE2_SME2 UINT64_C(0x0000002000000000) #endif #include @@ -1281,6 +1283,9 @@ static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) { if (have_feature("hw.optional.arm.FEAT_I8MM")) { cpuinfo_isa->i8mm = true; } + if (have_feature("hw.optional.arm.FEAT_SME2")) { + cpuinfo_isa->sme2 = true; + } } #endif @@ -1288,6 +1293,8 @@ static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) { static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) { // Use AUX to get info for linux-aarch64 uint32_t isa_features = 0; + uint64_t isa_features2 = 0; + isa_features = (uint32_t)getauxval(AT_HWCAP); if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) { cpuinfo_isa->dot = true; @@ -1299,10 +1306,14 @@ static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) { if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) { cpuinfo_isa->i8mm = true; } - isa_features = (uint32_t)getauxval(AT_HWCAP2); - if (isa_features & CPUINFO_ARM_LINUX_FEATURE_SVE2) { + + isa_features2 = (uint64_t)getauxval(AT_HWCAP2); + if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SVE2) { cpuinfo_isa->sve2 = true; } + if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SME2) { + cpuinfo_isa->sme2 = true; + } } #endif @@ -1353,6 +1364,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) { cpuinfo_isa->fp16arith = false; cpuinfo_isa->i8mm = false; cpuinfo_isa->sve2 = false; + cpuinfo_isa->sme2 = false; // android /**Get CPU Info*/ #ifdef __linux__ @@ -1449,6 +1461,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) { cpuinfo_isa->dot = true; #endif - MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2); + MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d, sme2: %d\n", + cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2, cpuinfo_isa->sme2); return; } diff --git a/source/backend/cpu/CPURuntime.hpp b/source/backend/cpu/CPURuntime.hpp index 7155e023b..a71142d50 100644 --- a/source/backend/cpu/CPURuntime.hpp +++ b/source/backend/cpu/CPURuntime.hpp @@ -21,6 +21,7 @@ struct MNNCPUInfo { bool dot; bool i8mm; bool sve2; + bool sme2; std::vector groups; int cpuNumber = 0; }; diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 8a3576af0..b30d3705c 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -47,7 +47,9 @@ if (MNN_KLEIDIAI) endif() list(APPEND MNN_SOURCES_KLEIDIAI ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp) + list(APPEND MNN_SOURCES_KLEIDIAI ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.cpp) list(APPEND MNN_HEADERS_KLEIDIAI ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h) + list(APPEND MNN_HEADERS_KLEIDIAI ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.h) # KleidiAI include_directories( diff --git a/source/backend/cpu/arm/mnn_kleidiai.cpp b/source/backend/cpu/arm/mnn_kleidiai.cpp index dc1f9169f..65ebd5168 100644 ---
a/source/backend/cpu/arm/mnn_kleidiai.cpp +++ b/source/backend/cpu/arm/mnn_kleidiai.cpp @@ -10,349 +10,257 @@ using namespace MNN; -KleidiAI *KleidiAI::instance = NULL; - -inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { - // Since we pack a float and int32 value at the end of the row, - // we must make sure that k is a multiple of 4 for memory alignment. - size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); - return kai_roundup(k, kr_sr_roundedup4); +#define FLT16_MAX 65504.0f +#define FLT16_MIN -65504.0f + +bool KleidiAI::mKaiInitialized = false; +KleidiAI *KleidiAI::mKaiInstance = NULL; +KleidiAI::StaticInfo KleidiAI::mStaticInfo; + +//Get instance. +KleidiAI& KleidiAI::getInstance(const MNNCPUInfo& gCPUInfo, bool bFP16, bool bBF16) { + if(!mKaiInstance) { + mKaiInstance = new KleidiAI; + mKaiInitialized = true; + + mStaticInfo.mFP16 = bFP16; + mStaticInfo.mBF16 = bBF16; + mStaticInfo.mDot = gCPUInfo.dot; + mStaticInfo.mI8mm = gCPUInfo.i8mm; + mStaticInfo.mSme2 = gCPUInfo.sme2; + + initKernelInfo(); + } + return *mKaiInstance; } -static void packQsi4cxps16s0Qs4cxs0s1( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { - KAI_ASSERT(num_groups == 1); - KAI_ASSERT(extra_bytes == 0); - KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT(rhs != NULL); - KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_packed != NULL); - KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); - KAI_ASSERT(params->lhs_zero_point == 1); - - const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); - const size_t k_internal = kai_k_roundedup(k, kr, sr); - const size_t dst_num_rows = kai_roundup(n, nr) / nr; - const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); - const size_t block_length_in_bytes = kr / sr; - const size_t k_interleaved_v = 16U; - const size_t rhs_stride = kai_roundup(k, 2) / 2; - - for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { - uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - - int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); - - // Initialize to zero the RHS reduction sums - memset(sums, 0, nr * sizeof(int32_t)); - - for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { - const size_t block_idx = dst_byte_idx / block_length_in_bytes; - const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; - const size_t super_block_idx = block_idx / nr; - const size_t nr_idx = block_idx % nr; - - const size_t k_adjustment = - ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; - const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; - const size_t k1_idx = k0_idx + k_interleaved_v; - const size_t n0_idx = dst_row_idx * nr + nr_idx; - - // Clamp the index to avoid out-of-bound reads - const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - - const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; - const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; - - uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; - uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; - - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; - } - - if (k1_idx < k) { - byte1 = 
rhs[src_addr_byte1]; - } - - // The following operations where we extract the values from the bytes - // can be also written in the following and less efficient manner: - /* - uint8_t src_x0_lo = 0; - uint8_t src_x0_hi = 0; - - if ((k0_idx % 2) == 0) { - src_x0_lo = (byte0 & 0x0F); - } else { - src_x0_lo = (byte0 >> 4); - } - - if ((k1_idx % 2) == 0) { - src_x0_hi = (byte1 & 0x0F); - } else { - src_x0_hi = (byte1 >> 4); - } - */ - const size_t shift_right_x0 = ((k0_idx + 1) % 2) * 4; - const size_t shift_right_x1 = ((k1_idx + 1) % 2) * 4; - - const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; - const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; - - sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; - - const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); - } +KleidiAI& KleidiAI::getInstance() { + if(!mKaiInstance) { + MNN_ASSERT(0); //Should never happen. + } + return *mKaiInstance; +} - // Adjust the reduction sums - for (size_t i = 0; i < nr; ++i) { - sums[i] = sums[i] * 16; - dst_row += sizeof(int32_t); - } +//Print +void KleidiAI::printInfo(AccelType type) { + if(type == AccelType::ACC_TYPE_ERROR) { + return; + } - // Adjust the scales - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; - dst_row += sizeof(float); - } + static const char * const names[] = { + "QI4_ASYM_CHNLQT", + "QI4_ASYM_BLKQT", + "QI4_SYM_CHNLQT", + "QI4_SYM_BLKQT", + "QI8_ASYM_CHNLQT", + "QI8_ASYM_BLKQT", + "QI8_SYM_CHNLQT", + "QI8_SYM_BLKQT", + "FP16", + "FP32", + "BF16", + }; + + KernelInfo *pInfo = &mStaticInfo.mKernelInfo[(size_t)type]; + if(pInfo->mKernelSupport) { + MNN_PRINT("\nKleidiAI is running! AccelType is %s. ", names[(size_t)type]); + } else { + MNN_PRINT("\nKleidiAI cannot accelerate! AccelType is %s. 
", names[(size_t)type]); + } - // Set the bias - if (bias == NULL) { - memset(dst_row, 0, nr * sizeof(float)); - } else { - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - ((float*)dst_row)[i] = bias[src_row_idx]; - } - } + if(mStaticInfo.mFP16) { + MNN_PRINT("Data type is FP16.\n"); + } else if(mStaticInfo.mBF16) { + MNN_PRINT("Data type is BF16.\n"); + } else { + MNN_PRINT("Data type is FP32.\n"); } } -static void packQs4cxs16s0Qsi8cx(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { - KAI_ASSERT(num_groups == 1); - KAI_ASSERT(extra_bytes == 0); - KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT(rhs != NULL); - KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_packed != NULL); - KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); - KAI_ASSERT(params->lhs_zero_point == 1); - - const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); - const size_t k_internal = kai_k_roundedup(k, kr, sr); - const size_t dst_num_rows = kai_roundup(n, nr) / nr; - const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); - const size_t block_length_in_bytes = kr / sr; - const size_t k_interleaved_v = 16U; - const size_t rhs_stride = kai_roundup(k, 2); - - for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { - uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - - int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); - - // Initialize to zero the RHS reduction sums - memset(sums, 0, nr * sizeof(int32_t)); - - for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { - const size_t block_idx = dst_byte_idx / block_length_in_bytes; - const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; - const size_t super_block_idx = block_idx / nr; - const size_t nr_idx = block_idx % nr; - - const size_t k_adjustment = - ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; - const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; - const size_t k1_idx = k0_idx + k_interleaved_v; - const size_t n0_idx = dst_row_idx * nr + nr_idx; - - // Clamp the index to avoid out-of-bound reads - const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - - const size_t src_addr_byte0 = k0_idx + n0_valid_idx * rhs_stride; - const size_t src_addr_byte1 = k1_idx + n0_valid_idx * rhs_stride; - - int8_t byte0 = 0; - int8_t byte1 = 0; - - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; +//Init +void KleidiAI::initKernelInfo() { + for(size_t type = 0; type < static_cast(AccelType::ACC_TYPE_NUMBER); type++) { + KernelInfo *pInfo = &mStaticInfo.mKernelInfo[type]; + bool bSupport = false; + + switch(static_cast(type)) { + case AccelType::QI4_SYM_CHNLQT: + { + bSupport = (mStaticInfo.mDot && mStaticInfo.mI8mm) && (!mStaticInfo.mFP16 && !mStaticInfo.mBF16); + if(bSupport) { + KernelParam *pParam = &pInfo->mKernelParam; + pParam->mKaiMstepGemv = 1; + pParam->mKaiMstepGemm = 8; + pParam->mKaiNStep = 4; + pParam->mKaiMrGemv = 1; + pParam->mKaiMrGemm = 4; + pParam->mKaiNr = 4; + pParam->mKaiKr = 16; + pParam->mKaiSr = 2; } - - if (k1_idx < k) { - byte1 = 
rhs[src_addr_byte1]; - } - - sums[nr_idx] += (int32_t)byte0 + (int32_t)byte1; - - const uint8_t dst_qs0 = (byte0 + rhs_zero_point) | ((byte1 + rhs_zero_point) << 4); - - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); + break; } - - // Adjust the reduction sums - for (size_t i = 0; i < nr; ++i) { - sums[i] = sums[i] * 16; - dst_row += sizeof(int32_t); - } - - // Adjust the scales - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; - dst_row += sizeof(float); + case AccelType::QI4_ASYM_CHNLQT: + case AccelType::QI4_ASYM_BLKQT: + case AccelType::QI4_SYM_BLKQT: + case AccelType::QI8_ASYM_CHNLQT: + case AccelType::QI8_ASYM_BLKQT: + case AccelType::QI8_SYM_CHNLQT: + case AccelType::QI8_SYM_BLKQT: + case AccelType::FP16: + case AccelType::FP32: + case AccelType::BF16: + break; + default: + MNN_ASSERT(0); + break; } - // Set the bias - if (bias == NULL) { - memset(dst_row, 0, nr * sizeof(float)); - } else { - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - ((float*)dst_row)[i] = bias[src_row_idx]; - } - } + pInfo->mKernelSupport = bSupport; } } -void KleidiAI::packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize) { - if(rowNum == 1) { - return; - } - - const size_t tmp_size = rowNum * rowSize * sizeof(float); - uint8_t *tmpBuffer = new uint8_t[tmp_size]; - memcpy(tmpBuffer, data, tmp_size); - - const float *src = (const float *)tmpBuffer; - float *dst = (float *)data; - - size_t blockNum = rowSize / 4; - size_t blockSize = 4 * sizeof(float); - - for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { - const float *rowSrc = src + blockIndex * 4; - for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { - memcpy(dst, rowSrc, blockSize); - dst += 4; - rowSrc += rowSize; - } +//Get Info +KleidiAI::AccelType KleidiAI::getQIntAccelType(size_t bits, bool bAsymmetric, size_t blockSize) { + static std::map infoMap = { + {KleidiAI::QIntInfo(4, true, 0), KleidiAI::AccelType::QI4_ASYM_CHNLQT}, + {KleidiAI::QIntInfo(4, true, -1), KleidiAI::AccelType::QI4_ASYM_BLKQT}, + {KleidiAI::QIntInfo(4, false, 0), KleidiAI::AccelType::QI4_SYM_CHNLQT}, + {KleidiAI::QIntInfo(4, false, -1), KleidiAI::AccelType::QI4_SYM_BLKQT}, + {KleidiAI::QIntInfo(8, true, 0), KleidiAI::AccelType::QI8_ASYM_CHNLQT}, + {KleidiAI::QIntInfo(8, true, -1), KleidiAI::AccelType::QI8_ASYM_BLKQT}, + {KleidiAI::QIntInfo(8, false, 0), KleidiAI::AccelType::QI8_SYM_CHNLQT}, + {KleidiAI::QIntInfo(8, false, -1), KleidiAI::AccelType::QI8_SYM_BLKQT}, + }; + + QIntInfo info(bits, bAsymmetric, blockSize); + auto it = infoMap.find(info); + if(it != infoMap.end()) { + return it->second; + } else { + return AccelType::ACC_TYPE_ERROR; } - - delete[] tmpBuffer; } -void KleidiAI::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) { - if(rowNum == 1) { - return; +//Lhs +size_t KleidiAI::getLhsQuantedPackedSize(AccelType type, size_t m, size_t k, size_t bl) { + MNN_ASSERT(type >= AccelType::QINT && type <= AccelType::QINT_END); + + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, getMr(type, m), getKr(type), getSr(type)); + default: + MNN_ASSERT(0); } - const size_t tmp_size = rowNum * rowSize * sizeof(float); - uint8_t *tmpBuffer = new uint8_t[tmp_size]; - memcpy(tmpBuffer, data, tmp_size); - - const 
float *src = (const float *)tmpBuffer; - float *dst = (float *)data; + return 0; +} - size_t blockNum = rowSize / 4; - size_t blockSize = 4 * sizeof(float); +size_t KleidiAI::getLhsQuantedPackedOffset(AccelType type, size_t m, size_t mIdx, size_t k, size_t bl) { + MNN_ASSERT(type >= AccelType::QINT && type <= AccelType::QINT_END); - for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { - const float *rowSrc = src + blockIndex * 4 * rowNum; - float *block_dst = dst + blockIndex * 4; - for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { - memcpy(block_dst, rowSrc, blockSize); - block_dst += rowSize; - rowSrc += 4; - } + if(mIdx == 0) { + return 0; } - delete[] tmpBuffer; -} - -//Set info -void KleidiAI::setEnable(bool enable) { - mKaiInfo.kaiEnable = enable; - if(canAccelerate()) { - MNN_PRINT("\nKleidiAI is running!\n"); + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + return kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(type, m), getKr(type), getSr(type)); + default: + MNN_ASSERT(0); } -} -void KleidiAI::setModelAsymmetric(bool bAsymmetric) { - mKaiInfo.asymmetric = bAsymmetric; - if(canAccelerate()) { - MNN_PRINT("\nKleidiAI is running!\n"); - } + return 0; } -//Lhs -size_t KleidiAI::getLhsQuantedPackedSize(size_t m, size_t k) { - return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), getKr(), getSr()); +void KleidiAI::runLhsPack(AccelType type, size_t m, size_t k, size_t mIdx, const void* lhs, size_t lhsStride, void* lhsPacked) +{ + MNN_ASSERT(type >= AccelType::FLOAT && type <= AccelType::FLOAT_END); + //For float ukernels, Not support yet. } -size_t KleidiAI::getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k) { - return mIdx == 0 ? 0 : kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(m), getKr(), getSr()); -} +void KleidiAI::runLhsQuantPack(AccelType type, size_t m, size_t k, size_t bl, size_t mr, const void* lhs, void* lhsQuantedPacked) { + MNN_ASSERT(type >= AccelType::QINT && type <= AccelType::QINT_END); -void KleidiAI::runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked) { - kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, getKr(type), getSr(type), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); + break; + default: + MNN_ASSERT(0); + } } //Rhs -size_t KleidiAI::getRhsPackedSize(size_t n, size_t k) { - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(), getKr(), getSr()); +size_t KleidiAI::getRhsPackedSize(AccelType type, size_t n, size_t k, size_t bl) { + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(type), getKr(type), getSr(type)); + default: + MNN_ASSERT(0); + return 0; + } } -size_t KleidiAI::getRhsPackedOffset(size_t nIdx, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); +size_t KleidiAI::getRhsPackedOffset(AccelType type, size_t nIdx, size_t k, size_t bl) { + if(nIdx == 0) { + return 0; + } + + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(nIdx, k, getNr(type), getKr(type), getSr(type)); + default: + MNN_ASSERT(0); + return 0; + } } -void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, 
void* rhsPacked, bool packedInt4) { - struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - if(!packedInt4) { - packQs4cxs16s0Qsi8cx(1, n, k, getNr(), getKr(), getSr(), - (const uint8_t *)rhs, - (const float *)bias, (const float *)scale, - rhsPacked, - 0, &params); - } else { - packQsi4cxps16s0Qs4cxs0s1(1, n, k, getNr(), getKr(), getSr(), - (const uint8_t *)rhs, - (const float *)bias, (const float *)scale, - rhsPacked, - 0, &params); + } + } +void KleidiAI::runRhsPack(AccelType type, size_t numGroups, size_t n, size_t k, size_t bl, size_t rhsStride, + const void* rhs, const void* scale, const void* zeroPoint, const void* bias, + void* rhsPacked, bool packedQ4) { + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + { + KleidiAIUtil::rhsPackParamCommon paramCommon; + if(packedQ4) { + KleidiAIUtil::packQsi4cxps16s0Qs4cxs0s1(numGroups, n, k, getNr(type), getKr(type), getSr(type), + (const uint8_t *)rhs, (const float *)bias, (const float *)scale, + rhsPacked, 0, &paramCommon); + } else { + KleidiAIUtil::packQsi4cxps16s0Qs4cx(numGroups, n, k, getNr(type), getKr(type), getSr(type), + (const uint8_t *)rhs, (const float *)bias, (const float *)scale, + rhsPacked, 0, &paramCommon); + } + break; + } + default: + MNN_ASSERT(0); } } //Matmul -void KleidiAI::runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst) { - if(m == 1) { //dotprod - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k, - (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); - } else { //i8mm - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k, - (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); +void KleidiAI::runMatmul(AccelType type, size_t m, size_t n, size_t k, size_t bl, + const void* lhsPacked, const void* rhsPacked, void* dst, + size_t dstStrideRow, size_t dstStrideCol) { + switch(type) { + case AccelType::QI4_SYM_CHNLQT: + { + const float scalar_max = FLT_MAX; + const float scalar_min = -scalar_max; + if(m == 1) { + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k, + (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, + dstStrideRow, dstStrideCol, scalar_min, scalar_max); + } else { + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k, + (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, + dstStrideRow, dstStrideCol, scalar_min, scalar_max); + } + break; + } + default: + MNN_ASSERT(0); } } diff --git a/source/backend/cpu/arm/mnn_kleidiai.h b/source/backend/cpu/arm/mnn_kleidiai.h index 38cdce230..06343c844 100644 --- a/source/backend/cpu/arm/mnn_kleidiai.h +++ b/source/backend/cpu/arm/mnn_kleidiai.h @@ -7,119 +7,226 @@ #pragma once #include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" - -#include "kai_common.h" +#include "core/Backend.hpp" +#include "core/Execution.hpp" +#include "core/TensorUtils.hpp" +#include "core/ConvolutionCommon.hpp" +#include "backend/cpu/CPUBackend.hpp" +#include "backend/cpu/CPURuntime.hpp" +#include "backend/cpu/compute/CommonOptFunction.h" + +#include 
"mnn_kleidiai_util.h" namespace MNN { class KleidiAI { public: - static KleidiAI &getInstance(bool bAsymmetric, bool acthalf, bool blockwise) { - if(!instance) { - instance = new KleidiAI(bAsymmetric, acthalf, blockwise); + // =================================================================== + // Enum definition + + enum class AccelType { + /* + ASYM/SYM: Asymmetric/symmetric; + CHNLQT/BLKQT: channel wise/block wise; + */ + QINT = 0, + QI4_ASYM_CHNLQT = QINT, + QI4_ASYM_BLKQT, + QI4_SYM_CHNLQT, + QI4_SYM_BLKQT, + QI8_ASYM_CHNLQT, + QI8_ASYM_BLKQT, + QI8_SYM_CHNLQT, + QI8_SYM_BLKQT, + QINT_END = QI8_SYM_BLKQT, + + FLOAT, + FP16 = FLOAT, + FP32, + BF16, + FLOAT_END = BF16, + + ACC_TYPE_NUMBER, + ACC_TYPE_ERROR = ACC_TYPE_NUMBER + }; + + // =================================================================== + // Some necessary data structures + typedef struct KernelParam { + size_t mKaiMstepGemv = 0; + size_t mKaiMstepGemm = 0; + size_t mKaiNStep = 0; + + size_t mKaiMrGemv = 0; + size_t mKaiMrGemm = 0; + size_t mKaiNr = 0; + size_t mKaiKr = 0; + size_t mKaiSr = 0; + } KernelParam; + + typedef struct KernelInfo { + bool mKernelSupport = false; + KernelParam mKernelParam; + } KernelInfo; + + typedef struct StaticInfo { + bool mFP16 = false; //fp16 or fp32. + bool mBF16 = false; //bf16 or fp32. + + bool mDot = false; + bool mI8mm = false; + bool mSme2 = false; + + KernelInfo mKernelInfo[(size_t)AccelType::ACC_TYPE_NUMBER]; + } StaticInfo; + + + typedef struct QIntInfo { + size_t mBits; + bool mAsymmetric; //Asymmetric quantized model. + size_t mBlockSize; //0: Per channel quant; others: Per block quant. + + QIntInfo(size_t bits = 4, bool asymmetric = false, size_t blockSize = 0) { + mBits = bits; + mAsymmetric = asymmetric; + mBlockSize = blockSize; } - return *instance; - } - static KleidiAI &getInstance() { - if(!instance) { - instance = new KleidiAI; + bool operator<(const QIntInfo& rhs) const { + if(mBits != rhs.mBits) { + return mBits < rhs.mBits; + } + + if(mAsymmetric != rhs.mAsymmetric) { + return mAsymmetric < rhs.mAsymmetric; + } + + bool lhsPerChannel = mBlockSize == 0 ? true : false; + bool rhsPerChannel = rhs.mBlockSize == 0 ? true : false; + return lhsPerChannel < rhsPerChannel; } - return *instance; - } + } QIntInfo; + + // =================================================================== + + //Public static members. + static bool mKaiInitialized; + + //Get instance. + static KleidiAI &getInstance(const MNNCPUInfo& gCPUInfo, bool bFP16, bool bBF16); + static KleidiAI &getInstance(); + static void initKernelInfo(); ~KleidiAI() {} - typedef struct KaiInfo { - bool kaiEnable = false; - bool asymmetric = false; //Asymmetric quantized model. - bool acthalf = false; // activation half precision. - bool blockwise = false; // weight quant using block wise. - bool dot = false; //CPU support sdot. - bool i8mm = false; //CPU support i8mm. 
- } KaiInfo; - - //Kai util - void packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize); - void packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize); - - //Set info - void setEnable(bool enable); - void setModelAsymmetric(bool bAsymmetric); - - //Check - bool canAccelerate() { - return (mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm && - !mKaiInfo.asymmetric && !mKaiInfo.acthalf && !mKaiInfo.blockwise); - } + void printInfo(AccelType type); + + //Check and set + bool canAccelerate(); + bool canAccelerate(AccelType type); + bool isLoaded(AccelType type); + void setLoaded(AccelType type) { mLoaded[(size_t)type] = true; } + bool isLinear() { return mLinear; } + void setLinear(bool bLinear) { mLinear = bLinear; } //Get info - size_t getMr(size_t m = 1) { return (m == 1) ? mKaiMrDotprod : mKaiMrI8mm; } - size_t getNr() { return mKaiNr; } - size_t getKr() { return mKaiKr; } - size_t getSr() { return mKaiSr; } - size_t getMStep(size_t m = 1) { return (m == 1) ? mKaiMstepDotprod : mKaiMstepI8mm; } - size_t getNStep() { return mKaiNStep; } - size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { return kai_roundup((totalVec + totalThread - 1) / totalThread, minStep); } + static AccelType getQIntAccelType(size_t bits, bool bAsymmetric, size_t blockSize); + size_t getMr(AccelType type, size_t m = 1); + size_t getNr(AccelType type); + size_t getKr(AccelType type); + size_t getSr(AccelType type); + size_t getMStep(AccelType type, size_t m = 1); + size_t getNStep(AccelType type); + size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep); + //Get Static info + bool isFP16() { return mStaticInfo.mFP16; } + bool isBF16() { return mStaticInfo.mBF16; } + bool isHalf() { return mStaticInfo.mFP16 || mStaticInfo.mBF16; } //Lhs - size_t getLhsQuantedPackedSize(size_t m, size_t k); - size_t getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k); - void runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked); + size_t getLhsQuantedPackedSize(AccelType type, size_t m, size_t k, size_t bl); + size_t getLhsQuantedPackedOffset(AccelType type, size_t m, size_t mIdx, size_t k, size_t bl); + void runLhsPack(AccelType type, size_t m, size_t k, size_t mIdx, const void* lhs, size_t lhsStride, void* lhsPacked); + void runLhsQuantPack(AccelType type, size_t m, size_t k, size_t bl, size_t mr, const void* lhs, void* lhsQuantedPacked); //Rhs - size_t getRhsPackedSize(size_t n, size_t k); - size_t getRhsPackedOffset(size_t nIdx, size_t k); - void runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4 = false); - + size_t getRhsPackedSize(AccelType type, size_t n, size_t k, size_t bl); + size_t getRhsPackedOffset(AccelType type, size_t nIdx, size_t k, size_t bl); + void runRhsPack(AccelType type, size_t numGroups, size_t n, size_t k, size_t bl, size_t rhsStride, + const void* rhs, const void* scale, const void* zeroPoint, const void* bias, + void* rhsPacked, bool packedQ4); //Dst - size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n) { return (nIdx * sizeof(float)) + mIdx * (n * sizeof(float)); } + size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n, size_t elementSize) { return (nIdx * elementSize) + mIdx * (n * elementSize); } //Matmul - void runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst); + void runMatmul(AccelType type, size_t m, size_t n, size_t k, size_t bl, + const void* lhsPacked, 
const void* rhsPacked, void* dst, + size_t dstStrideRow, size_t dstStrideCol); private: - KleidiAI(bool bAsymmetric = false, bool acthalf = false, bool blockwise = false) { - const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); - mKaiInfo.dot = gCPUInfo.dot; - mKaiInfo.i8mm = gCPUInfo.i8mm; - mKaiInfo.kaiEnable = true; - mKaiInfo.asymmetric = bAsymmetric; - mKaiInfo.acthalf = acthalf; - mKaiInfo.blockwise = blockwise; - - if(canAccelerate()) { - MNN_PRINT("\nKleidiAI is running!\n"); + KleidiAI() {} + + static KleidiAI *mKaiInstance; + //Static info, never change after construct. + static StaticInfo mStaticInfo; + //Status, will change while pipeline is running. + bool mLoaded[(size_t)AccelType::ACC_TYPE_NUMBER] = { false }; + bool mLinear = false; //All pipeline format has been set as NCHW. + }; + + // =================================================================== + // Inline functions + inline bool KleidiAI::canAccelerate() { + for(size_t type = 0; type < (size_t)AccelType::ACC_TYPE_NUMBER; type++) { + if(mStaticInfo.mKernelInfo[(size_t)type].mKernelSupport && isLoaded(static_cast(type))) { + return true; } } + return false; + } - static KleidiAI *instance; - KaiInfo mKaiInfo; - - const size_t mKaiMstepDotprod = 1; - const size_t mKaiMstepI8mm = 8; - const size_t mKaiNStep = 4; - - const size_t mKaiMrDotprod = 1; - const size_t mKaiMrI8mm = 4; - const size_t mKaiNr = 4; - const size_t mKaiKr = 16; - const size_t mKaiSr = 2; - }; + inline bool KleidiAI::canAccelerate(AccelType type) { + if(type >= AccelType::ACC_TYPE_ERROR) { + return false; + } + return mStaticInfo.mKernelInfo[(size_t)type].mKernelSupport; + } + + inline bool KleidiAI::isLoaded(AccelType type) { + MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER); + return mLoaded[(size_t)type]; + } + + inline size_t KleidiAI::getMr(AccelType type, size_t m) { + KernelParam *pParam = &mStaticInfo.mKernelInfo[(size_t)type].mKernelParam; + return (m == 1) ? pParam->mKaiMrGemv : pParam->mKaiMrGemm; + } + + inline size_t KleidiAI::getNr(AccelType type) { + KernelParam *pParam = &mStaticInfo.mKernelInfo[(size_t)type].mKernelParam; + return pParam->mKaiNr; + } + + inline size_t KleidiAI::getKr(AccelType type) { + KernelParam *pParam = &mStaticInfo.mKernelInfo[(size_t)type].mKernelParam; + return pParam->mKaiKr; + } + + inline size_t KleidiAI::getSr(AccelType type) { + KernelParam *pParam = &mStaticInfo.mKernelInfo[(size_t)type].mKernelParam; + return pParam->mKaiSr; + } + + inline size_t KleidiAI::getMStep(AccelType type, size_t m) { + KernelParam *pParam = &mStaticInfo.mKernelInfo[(size_t)type].mKernelParam; + return (m == 1) ? 
pParam->mKaiMstepGemv : pParam->mKaiMstepGemm; + } + + inline size_t KleidiAI::getNStep(AccelType type) { + KernelParam *pParam = &mStaticInfo.mKernelInfo[(size_t)type].mKernelParam; + return pParam->mKaiNStep; + } + + inline size_t KleidiAI::getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { + return kai_roundup((totalVec + totalThread - 1) / totalThread, minStep); + } } \ No newline at end of file diff --git a/source/backend/cpu/arm/mnn_kleidiai_util.cpp b/source/backend/cpu/arm/mnn_kleidiai_util.cpp new file mode 100644 index 000000000..8137b427e --- /dev/null +++ b/source/backend/cpu/arm/mnn_kleidiai_util.cpp @@ -0,0 +1,328 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "mnn_kleidiai_util.h" + +using namespace MNN; + +static const size_t kai_num_bytes_adder_rhs = 4; //sizeof(int32_t) or sizeof(float) +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_num_bytes_per_block(size_t bl) { + return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_adder_rhs; +} + +inline static size_t kai_rhs_packed_stride_q4c32p(size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kr) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); + + return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + // multiple of 2 because 2 elements in a byte + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_adder_rhs + kai_num_bytes_bias); +} + +void KleidiAIUtil::transferNCHWToNC4HW4(float* src, float* dst, size_t rowNum, size_t rowSize) { + size_t blockNum = rowSize / 4; + size_t blockSize = 4 * sizeof(float); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const float *rowSrc = src + blockIndex * 4; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(dst, rowSrc, blockSize); + dst += 4; + rowSrc += rowSize; + } + } +} + +void KleidiAIUtil::transferNCHWToNC4HW4(__fp16* src, __fp16* dst, size_t rowNum, size_t rowSize) { + size_t blockNum = rowSize / 8; + size_t blockSize = 8 * sizeof(__fp16); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const __fp16 *rowSrc = src + blockIndex * 8; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(dst, rowSrc, blockSize); + dst += 8; + rowSrc += rowSize; + } + } +} + +void KleidiAIUtil::transferNC4HW4ToNCHW(float* src, float* dst, size_t rowNum, size_t rowSize) { + size_t blockNum = rowSize / 4; + size_t blockSize = 4 * sizeof(float); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const 
float *rowSrc = src + blockIndex * 4 * rowNum; + float *block_dst = dst + blockIndex * 4; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(block_dst, rowSrc, blockSize); + block_dst += rowSize; + rowSrc += 4; + } + } +} + +void KleidiAIUtil::transferNC4HW4ToNCHW(__fp16* src, __fp16* dst, size_t rowNum, size_t rowSize) { + size_t blockNum = rowSize / 8; + size_t blockSize = 8 * sizeof(__fp16); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const __fp16 *rowSrc = src + blockIndex * 8 * rowNum; + __fp16 *block_dst = dst + blockIndex * 8; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(block_dst, rowSrc, blockSize); + block_dst += rowSize; + rowSrc += 8; + } + } +} + +// Rhs pack functions for matmul_clamp_f32_qai8dxp_qsi4cxp. +void KleidiAIUtil::packQsi4cxps16s0Qs4cxs0s1( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params = (kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params *)paramsCommon; + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the 
following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = ((k0_idx + 1) % 2) * 4; + const size_t shift_right_x1 = ((k1_idx + 1) % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +void KleidiAIUtil::packQsi4cxps16s0Qs4cx( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params = (kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params *)paramsCommon; + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2); + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = 
dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = k0_idx + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = k1_idx + n0_valid_idx * rhs_stride; + + int8_t byte0 = 0; + int8_t byte1 = 0; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + sums[nr_idx] += (int32_t)byte0 + (int32_t)byte1; + + const uint8_t dst_qs0 = (byte0 + rhs_zero_point) | ((byte1 + rhs_zero_point) << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} \ No newline at end of file diff --git a/source/backend/cpu/arm/mnn_kleidiai_util.h b/source/backend/cpu/arm/mnn_kleidiai_util.h new file mode 100644 index 000000000..1af5375c8 --- /dev/null +++ b/source/backend/cpu/arm/mnn_kleidiai_util.h @@ -0,0 +1,63 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include + +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" + +#include "kai_common.h" + +namespace MNN { + class KleidiAIUtil { + public: + struct rhsPackParamCommon { + int8_t mLhsZeroPoint = 1; + uint8_t mRhsZeroPoint = 8; + }; + + static void transferNCHWToNC4HW4(float* src, float* dst, size_t rowNum, size_t rowSize); + static void transferNCHWToNC4HW4(__fp16* src, __fp16* dst, size_t rowNum, size_t rowSize); + static void transferNC4HW4ToNCHW(float* src, float* dst, size_t rowNum, size_t rowSize); + static void transferNC4HW4ToNCHW(__fp16* src, __fp16* dst, size_t rowNum, size_t rowSize); + + /// Rhs pack functions for matmul_clamp_f32_qai8dxp_qsi4cxp. 
+ static void packQsi4cxps16s0Qs4cxs0s1( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + const uint8_t* rhs, // + const float* bias, // + const float* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon); // + + static void packQsi4cxps16s0Qs4cx( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + const uint8_t* rhs, // + const float* bias, // + const float* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon); // +}; +} \ No newline at end of file diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 1867fa3fd..91088258e 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -279,36 +279,65 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int ic = convOp->common()->inputCount(); bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic); bool useCachedMmap = backend->getRuntime()->hint().useCachedMmap > 1; + #ifdef MNN_KLEIDIAI_ENABLED - bool half_act = gcore->bytes == 2; - int biasSize = mResourceInt8->mOriginBias->size(); - int alphaSize = mResourceInt8->mOriginScale->size(); - bool blockwise = (biasSize * 2) != alphaSize; - KleidiAI kai = KleidiAI::getInstance(quanCommon->asymmetric, half_act, blockwise); - if(quanCommon->canUseInt4 && kai.canAccelerate()) { - int n = oc; - int k = ic; - int packedWeightSize = kai.getRhsPackedSize(n, k); - - //Alloc packed weight tensor. - mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize})); - bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC); + if(quanCommon->canUseInt4) { + bool bFP16 = gcore->bytes == 2 ? true : false; + bool bAsym = quanCommon->asymmetric; + size_t blkSize = mBlockNum == 1 ? 0 : ic / mBlockNum; + KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize); + + if(!KleidiAI::mKaiInitialized) { + KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo(), bFP16, false); + } - if (!success) { - MNN_ERROR("Out of static memory!\n"); - return; + KleidiAI& kai = KleidiAI::getInstance(); + if(!kai.isLoaded(accelType)) { + kai.setLoaded(accelType); + kai.printInfo(accelType); } - //Run rhs pack. - kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(), - mResourceInt8->mOriginScale->host(), - mResourceInt8->mOriginBias->host(), - mResourceInt8->mWeightInt8->host(), - directReadInt4weight); + if(kai.canAccelerate(accelType)) { + mAccelType = accelType; + int n = oc; + int k = ic; + int packedWeightSize = kai.getRhsPackedSize(mAccelType, n, k, blkSize); - return; - } + //Alloc packed weight tensor. + mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize})); + bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC); + + if (!success) { + MNN_ERROR("Out of static memory!\n"); + return; + } + size_t paraNum = blockNum * ROUND_UP(oc, pack); + float *scalePtr = mResourceInt8->mOriginScale->host(); + float *zeroPtr = mResourceInt8->mOriginScale->host() + paraNum; + float *biasPtr = mResourceInt8->mOriginBias->host(); + //Reload some parameters to fit ukernels' layout. 
+ auto quanInfoPtr = quanCommon->alpha.get(); + if(bAsym) { + for(int i = 0; i < paraNum; i++) { + zeroPtr[i] = quanInfoPtr[i * 2]; + scalePtr[i] = quanInfoPtr[i * 2 + 1]; + } + } else { + if(blkSize != 0) { + memcpy(scalePtr, (uint8_t*)quanInfoPtr, paraNum * sizeof(float)); + } + } + + //Run rhs pack. + auto weightPackedData = mResourceInt8->mWeightInt8->host(); + kai.runRhsPack(mAccelType, 1, n, k, blkSize, 0/*unused*/, + (uint8_t*)quanCommon->weight.get(), + (const void*)scalePtr, (const void*)zeroPtr, (const void*)biasPtr, + weightPackedData, directReadInt4weight); + return; + } + } #endif if (quanCommon->canUseInt4 && directReadInt4weight) { @@ -522,6 +551,9 @@ bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** if (!exe->valid()) { return false; } +#ifdef MNN_KLEIDIAI_ENABLED + exe->mAccelType = this->mAccelType; +#endif *dst = exe; return true; } @@ -549,11 +581,25 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input #ifdef MNN_KLEIDIAI_ENABLED KleidiAI& kai = KleidiAI::getInstance(); - if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) { - int batch = inputs[0]->batch(); - int channel = inputs[0]->channel(); + if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate(mAccelType)) { + MNN_ASSERT(kai.isLoaded(mAccelType)); + const size_t m = inputs[0]->batch(); //lhs vector number. + const size_t n = outputs[0]->channel(); //rhs vector number. + const size_t k = inputs[0]->channel(); //vector size. + const size_t blkSize = mBlockNum == 1 ? 0 : k / mBlockNum; + + int packedSize = kai.getLhsQuantedPackedSize(mAccelType, m, k, blkSize); + int elementSize = kai.isHalf() ? sizeof(__fp16) : sizeof(float); + if(m > 1 && !kai.isLinear()) { + int srcSize = m * k * elementSize; + int dstSize = m * n * elementSize; + int extraSize = srcSize > dstSize ? srcSize : dstSize; + packedSize += extraSize; + } - int packedSize = kai.getLhsQuantedPackedSize(batch, channel); + //Split mTempIm2ColBuffer as two parts for linear/tile transfer: + //Part0: Lhs_packed. + //Part1: Lhs/Dst before transfer. mTempIm2ColBuffer.reset(Tensor::createDevice({packedSize})); bool success = backend()->onAcquireBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); if (!success) { @@ -740,36 +786,55 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu #ifdef MNN_KLEIDIAI_ENABLED KleidiAI& kai = KleidiAI::getInstance(); - if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) { + if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate(mAccelType)) { + MNN_ASSERT(kai.isLoaded(mAccelType)); const size_t m = input->batch(); //lhs vector number. const size_t n = output->channel(); //rhs vector number. const size_t k = input->channel(); //vector size. + const size_t blkSize = mBlockNum == 1 ? 0 : k / mBlockNum; + + bool bHalf = kai.isHalf(); + size_t elementSize = bHalf ? 
sizeof(__fp16) : sizeof(float); + size_t lhsPackedSize = kai.getLhsQuantedPackedSize(mAccelType, m, k, blkSize); auto lhs = input->host(); auto lhsPacked = mTempIm2ColBuffer->host(); auto rhsPacked = mResourceInt8->mWeightInt8->host(); auto dst = output->host(); + uint8_t *linearLhs, *linearDst; + if(m > 1 && !kai.isLinear()) { + linearLhs = (uint8_t *)lhsPacked + lhsPackedSize; + linearDst = linearLhs; + } else { + linearLhs = lhs; + linearDst = dst; + } + int threadNum = static_cast(backend())->threadNumber(); int threadNeed, vecPerThread; -#if !KAI_CONV_NCHW_IN_OUT - kai.packNC4HW4ToNCHW((float *)lhs, m, k); -#endif - //Dynamic quant pack lhs. if(m == 1) { - kai.runLhsQuantPack(1, k, 1, lhs, lhsPacked); + kai.runLhsQuantPack(mAccelType, 1, k, blkSize, 1, linearLhs, lhsPacked); } else { - vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(m)); + if(!kai.isLinear()) { + if(bHalf) { + KleidiAIUtil::transferNC4HW4ToNCHW((__fp16 *)lhs, (__fp16 *)linearLhs, m, k); + } else { + KleidiAIUtil::transferNC4HW4ToNCHW((float *)lhs, (float *)linearLhs, m, k); + } + } + + vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(mAccelType, m)); threadNeed = m % vecPerThread == 0 ? m / vecPerThread : (m / vecPerThread + 1); - size_t srcStride = vecPerThread * k * sizeof(float); + size_t srcStride = vecPerThread * k * elementSize; auto BatchDynamicQuant = [=, &kai](int tId) { - auto threadSrc = lhs + tId * srcStride; - auto threadDst = lhsPacked + kai.getLhsQuantedPackedOffset(m, tId * vecPerThread, k); + auto threadSrc = linearLhs + tId * srcStride; + auto threadDst = lhsPacked + kai.getLhsQuantedPackedOffset(mAccelType, m, tId * vecPerThread, k, blkSize); int vecNum = (tId == threadNeed - 1) ? (m - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. - kai.runLhsQuantPack(vecNum, k, kai.getMr(m), threadSrc, threadDst); + kai.runLhsQuantPack(mAccelType, vecNum, k, blkSize, kai.getMr(mAccelType, m), threadSrc, threadDst); }; MNN_CONCURRENCY_BEGIN(tId, threadNeed) { @@ -778,15 +843,14 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu MNN_CONCURRENCY_END(); } - //Run matmul. - vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep()); + vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep(mAccelType)); threadNeed = n % vecPerThread == 0 ? n / vecPerThread : (n / vecPerThread + 1); auto ThreadFunction = [=, &kai](int tId) { - auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * vecPerThread, k); - auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n); + auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(mAccelType, tId * vecPerThread, k, blkSize); + auto threadDst = linearDst + kai.getDstOffset(0, tId * vecPerThread, n, elementSize); int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. 
- kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + kai.runMatmul(mAccelType, m, vecNum, k, blkSize, lhsPacked, threadRhsPacked, threadDst, n * elementSize, elementSize); }; MNN_CONCURRENCY_BEGIN(tId, threadNeed) { @@ -794,9 +858,13 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } MNN_CONCURRENCY_END(); -#if !KAI_CONV_NCHW_IN_OUT - kai.packNCHWToNC4HW4((float *)dst, m, n); -#endif + if(m > 1 && !kai.isLinear()) { + if(bHalf) { + KleidiAIUtil::transferNCHWToNC4HW4((__fp16 *)linearDst, (__fp16 *)dst, m, n); + } else { + KleidiAIUtil::transferNCHWToNC4HW4((float *)linearDst, (float *)dst, m, n); + } + } return NO_ERROR; } diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp index 6c46b9161..54dcfab08 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp @@ -78,6 +78,9 @@ class DenseConvInt8TiledExecutor : public ConvInt8TiledExecutor { int mOcPerThread; bool mSplitByOc; bool mUseBatchQuan; +#ifdef MNN_KLEIDIAI_ENABLED + KleidiAI::AccelType mAccelType = KleidiAI::AccelType::ACC_TYPE_NUMBER; +#endif }; } // namespace MNN diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 46b0e9821..45722c2ce 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -22,9 +22,9 @@ #ifdef MNN_KLEIDIAI_ENABLED #include "../backend/cpu/arm/mnn_kleidiai.h" /** - * Set DenseConvInt8TiledExecutor's input/output tensor format: - * KAI_CONV_NCHW_IN_OUT = 1: format will be NCHW, skip pack/unpack functions. - * KAI_CONV_NCHW_IN_OUT = 0: format will be NC4HW4, need pack/unpack functions to fit kleidiAI ukernel. + * Set Convolution's input/output tensor format: + * 1: format will be NCHW, skip pack/unpack functions. + * 0: format will be NC4HW4, need pack/unpack functions to fit kleidiAI ukernel. **/ #define KAI_CONV_NCHW_IN_OUT 1 #endif diff --git a/source/geometry/GeometryConvUtils.cpp b/source/geometry/GeometryConvUtils.cpp index 21670bd24..85ad7a8fb 100644 --- a/source/geometry/GeometryConvUtils.cpp +++ b/source/geometry/GeometryConvUtils.cpp @@ -248,7 +248,12 @@ std::shared_ptr GeometryConvUtils::im2Col(Tensor* im2Col, Tensor* input, } bool GeometryConvUtils::computeSingle(const Op* op, const std::vector& inputs, const std::vector& outputs, GeometryComputer::Context& context, CommandBuffer& res) { #if KAI_CONV_NCHW_IN_OUT - if(KleidiAI::getInstance().canAccelerate()) { + KleidiAI& kai = KleidiAI::getInstance(); + if(kai.canAccelerate()) { + TensorUtils::getDescribe(inputs[0])->dimensionFormat = MNN_DATA_FORMAT_NCHW; + TensorUtils::getDescribe(outputs[0])->dimensionFormat = MNN_DATA_FORMAT_NCHW; + kai.setLinear(true); + std::shared_ptr cmd(new Command); cmd->op = op; cmd->inputs = std::move(inputs); diff --git a/source/shape/ShapeTensorConvert.cpp b/source/shape/ShapeTensorConvert.cpp index 899b9410b..a3d2035e0 100644 --- a/source/shape/ShapeTensorConvert.cpp +++ b/source/shape/ShapeTensorConvert.cpp @@ -23,11 +23,6 @@ class TensorConvertSizeComputer : public SizeComputer { sourceFmt = MNN_DATA_FORMAT_NCHW; } auto destFmt = info->dest(); -#if KAI_CONV_NCHW_IN_OUT - if(KleidiAI::getInstance().canAccelerate()) { - destFmt = MNN_DATA_FORMAT_NCHW; - } -#endif TensorUtils::getDescribe(outputs[0])->dimensionFormat = destFmt; if (destFmt == MNN_DATA_FORMAT_NC4HW4) { destFmt = MNN_DATA_FORMAT_NCHW;