Refine some functions.
xhzheng1895 committed Jan 9, 2025
1 parent 82efb15 commit 73344cd
Showing 4 changed files with 20 additions and 33 deletions.
source/backend/cpu/arm/mnn_kleidiai.cpp (20 changes: 15 additions & 5 deletions)
@@ -145,6 +145,7 @@ KleidiAI::AccelType KleidiAI::getQIntAccelType(size_t bits, bool bAsymmetric, si
 //Lhs
 size_t KleidiAI::getLhsQuantedPackedSize(AccelType type, size_t m, size_t k, size_t bl) {
     MNN_ASSERT(type >= AccelType::QINT && type <= AccelType::QINT_END);
+    KAI_UNUSED(bl);
 
     switch(type) {
         case AccelType::QI4_SYM_CHNLQT:
@@ -158,6 +159,7 @@ size_t KleidiAI::getLhsQuantedPackedSize(AccelType type, size_t m, size_t k, siz
 
 size_t KleidiAI::getLhsQuantedPackedOffset(AccelType type, size_t m, size_t mIdx, size_t k, size_t bl) {
     MNN_ASSERT(type >= AccelType::QINT && type <= AccelType::QINT_END);
+    KAI_UNUSED(bl);
 
     if(mIdx == 0) {
         return 0;
@@ -181,6 +183,7 @@ void KleidiAI::runLhsPack(AccelType type, size_t m, size_t k, size_t mIdx, const
 
 void KleidiAI::runLhsQuantPack(AccelType type, size_t m, size_t k, size_t bl, size_t mr, const void* lhs, void* lhsQuantedPacked) {
     MNN_ASSERT(type >= AccelType::QINT && type <= AccelType::QINT_END);
+    KAI_UNUSED(bl);
 
     switch(type) {
         case AccelType::QI4_SYM_CHNLQT:
@@ -193,6 +196,8 @@ void KleidiAI::runLhsQuantPack(AccelType type, size_t m, size_t k, size_t bl, si
 
 //Rhs
 size_t KleidiAI::getRhsPackedSize(AccelType type, size_t n, size_t k, size_t bl) {
+    KAI_UNUSED(bl);
+
     switch(type) {
         case AccelType::QI4_SYM_CHNLQT:
             return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(type), getKr(type), getSr(type));
@@ -203,6 +208,8 @@ size_t KleidiAI::getRhsPackedSize(AccelType type, size_t n, size_t k, size_t bl)
 }
 
 size_t KleidiAI::getRhsPackedOffset(AccelType type, size_t nIdx, size_t k, size_t bl) {
+    KAI_UNUSED(bl);
+
     if(nIdx == 0) {
         return 0;
     }
@@ -219,6 +226,8 @@ size_t KleidiAI::getRhsPackedOffset(AccelType type, size_t nIdx, size_t k, size_
 void KleidiAI::runRhsPack(AccelType type, size_t numGroups, size_t n, size_t k, size_t bl, size_t rhsStride,
                           const void* rhs, const void* scale, const void* zeroPoint, const void* bias,
                           void* rhsPacked, bool packedQ4) {
+    KAI_UNUSED(bl);
+
     switch(type) {
         case AccelType::QI4_SYM_CHNLQT:
         {
@@ -242,20 +251,21 @@ void KleidiAI::runRhsPack(AccelType type, size_t numGroups, size_t n, size_t k,
 //Matmul
 void KleidiAI::runMatmul(AccelType type, size_t m, size_t n, size_t k, size_t bl,
                          const void* lhsPacked, const void* rhsPacked, void* dst,
-                         size_t dstStrideRow, size_t dstStrideCol) {
+                         size_t dstStrideRow, size_t dstStrideCol,
+                         const float scalarMax, const float scalarMin) {
+    KAI_UNUSED(bl);
+
     switch(type) {
         case AccelType::QI4_SYM_CHNLQT:
         {
-            const float scalar_max = FLT_MAX;
-            const float scalar_min = -scalar_max;
             if(m == 1) {
                 kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k,
                     (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst,
-                    dstStrideRow, dstStrideCol, scalar_min, scalar_max);
+                    dstStrideRow, dstStrideCol, scalarMin, scalarMax);
             } else {
                 kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k,
                     (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst,
-                    dstStrideRow, dstStrideCol, scalar_min, scalar_max);
+                    dstStrideRow, dstStrideCol, scalarMin, scalarMax);
             }
             break;
         }
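For context, a minimal standalone sketch (not MNN code) of the clamp semantics behind the new scalarMax/scalarMin parameters: runMatmul now forwards caller-supplied bounds to the kai_run_matmul_clamp_f32_* kernels instead of hardcoding FLT_MAX, and passing FLT_MAX / -FLT_MAX, as the existing call site does, leaves the output unclamped. The helper and the example values below are illustrative assumptions.

// Standalone sketch: output clamping as performed by the clamp_f32 matmul kernels.
#include <algorithm>
#include <cfloat>
#include <cstdio>

static float clamp_output(float acc, float scalarMin, float scalarMax) {
    // Clamp the accumulator into [scalarMin, scalarMax].
    return std::min(std::max(acc, scalarMin), scalarMax);
}

int main() {
    // With the bounds used by the existing call site, values pass through unchanged.
    printf("%f\n", clamp_output(123.5f, -FLT_MAX, FLT_MAX)); // 123.500000
    // A caller could now request a fused ReLU6-style clamp instead (illustrative).
    printf("%f\n", clamp_output(123.5f, 0.0f, 6.0f));        // 6.000000
    return 0;
}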
source/backend/cpu/arm/mnn_kleidiai.h (4 changes: 3 additions & 1 deletion)
@@ -154,13 +154,15 @@ namespace MNN {
     void runRhsPack(AccelType type, size_t numGroups, size_t n, size_t k, size_t bl, size_t rhsStride,
                     const void* rhs, const void* scale, const void* zeroPoint, const void* bias,
                     void* rhsPacked, bool packedQ4);
+
     //Dst
     size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n, size_t elementSize) { return (nIdx * elementSize) + mIdx * (n * elementSize); }
 
     //Matmul
     void runMatmul(AccelType type, size_t m, size_t n, size_t k, size_t bl,
                    const void* lhsPacked, const void* rhsPacked, void* dst,
-                   size_t dstStrideRow, size_t dstStrideCol);
+                   size_t dstStrideRow, size_t dstStrideCol,
+                   const float scalarMax, const float scalarMin);
 
 private:
     KleidiAI() {}
source/backend/cpu/arm/mnn_kleidiai_util.cpp (27 changes: 1 addition & 26 deletions)
@@ -25,31 +25,6 @@ inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
     return k / bl;
 }
 
-inline static size_t kai_num_bytes_per_block(size_t bl) {
-    return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_adder_rhs;
-}
-
-inline static size_t kai_rhs_packed_stride_q4c32p(size_t k, size_t nr, size_t kr, size_t bl) {
-    KAI_ASSUME((k % 2) == 0);
-    KAI_ASSUME((k % kr) == 0);
-    KAI_ASSUME((k % bl) == 0);
-    KAI_ASSUME((bl % kr) == 0);
-
-    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
-    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
-
-    return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias);
-}
-
-inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_t sr) {
-    const size_t k_internal = kai_k_roundedup(k, kr, sr);
-
-    // multiple of 2 because 2 elements in a byte
-    KAI_ASSERT((k_internal % 2) == 0);
-
-    return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_adder_rhs + kai_num_bytes_bias);
-}
-
 void KleidiAIUtil::transferNCHWToNC4HW4(float* src, float* dst, size_t rowNum, size_t rowSize) {
     size_t blockNum = rowSize / 4;
     size_t blockSize = 4 * sizeof(float);
@@ -247,7 +222,7 @@ void KleidiAIUtil::packQsi4cxps16s0Qs4cx(
     KAI_ASSERT(params->lhs_zero_point == 1);
 
     const size_t rhs_zero_point = params->rhs_zero_point;
-    const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, sr);
+    const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
     const size_t k_internal = kai_k_roundedup(k, kr, sr);
     const size_t dst_num_rows = kai_roundup(n, nr) / nr;
     const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2);
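The deleted helpers duplicated packed-stride math that the KleidiAI library already exposes, so packQsi4cxps16s0Qs4cx now calls kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 directly; note that the removed kai_rhs_packed_stride declared its parameters as (k, kr, nr, sr) while the old call site passed (k, nr, kr, sr), a mismatch the library getter sidesteps. For reference, a standalone sketch of the arithmetic the deleted helper performed; the byte-size constants and the rounding rule below are assumptions for illustration, not copied from KleidiAI.

// Standalone sketch of the deleted rhs-packed-stride computation (illustrative only).
#include <cstddef>
#include <cstdio>

static size_t k_roundedup(size_t k, size_t kr, size_t sr) {
    const size_t step = kr * sr;               // assumed pack granularity along K
    return ((k + step - 1) / step) * step;     // round K up to a multiple of kr*sr
}

static size_t rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t bytes_multiplier = 4;         // assumed: one f32 scale per channel
    const size_t bytes_adder      = 4;         // assumed: one i32 adder per channel
    const size_t bytes_bias       = 4;         // assumed: one f32 bias per channel
    const size_t k_internal = k_roundedup(k, kr, sr);
    // Two 4-bit weights per byte, plus per-channel metadata, for nr channels per block row.
    return nr * ((k_internal / 2) + bytes_multiplier + bytes_adder + bytes_bias);
}

int main() {
    printf("%zu\n", rhs_packed_stride(/*k=*/256, /*nr=*/4, /*kr=*/16, /*sr=*/2)); // 560
    return 0;
}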
source/backend/cpu/compute/ConvInt8TiledExecutor.cpp (2 changes: 1 addition & 1 deletion)
@@ -850,7 +850,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
         auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(mAccelType, tId * vecPerThread, k, blkSize);
         auto threadDst = linearDst + kai.getDstOffset(0, tId * vecPerThread, n, elementSize);
         int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread.
-        kai.runMatmul(mAccelType, m, vecNum, k, blkSize, lhsPacked, threadRhsPacked, threadDst, n * elementSize, elementSize);
+        kai.runMatmul(mAccelType, m, vecNum, k, blkSize, lhsPacked, threadRhsPacked, threadDst, n * elementSize, elementSize, FLT_MAX, -FLT_MAX);
     };
 
     MNN_CONCURRENCY_BEGIN(tId, threadNeed) {
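For context, a standalone sketch (not MNN code) of the per-thread column split surrounding the updated runMatmul call: each thread handles vecPerThread output columns and the last thread takes whatever remains, exactly as the vecNum expression in the diff computes. The driver values and the split rule for vecPerThread below are illustrative assumptions.

// Standalone sketch of the per-thread output-column partitioning (illustrative only).
#include <cstdio>

int main() {
    const int n = 70;             // assumed total output columns
    const int threadNeed = 4;     // assumed thread count
    const int vecPerThread = (n + threadNeed - 1) / threadNeed;  // assumed split rule
    for (int tId = 0; tId < threadNeed; ++tId) {
        const int start  = tId * vecPerThread;
        // Last thread may get fewer than vecPerThread columns, mirroring the diff.
        const int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread;
        printf("thread %d: columns [%d, %d)\n", tId, start, start + vecNum);
    }
    return 0;
}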
