diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 02d42fbabb1d3..cb7cfbb4fb97a 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -305,8 +305,36 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
     if (input_idx == InputIndex::scales && packed_b_ != nullptr &&
         MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_, &mlas_backend_kernel_selector_config_)) {
+      // For asymmetric quantization, we require zero_points to be a constant initializer
+      // in order to safely use the KleidiAI packed-scales path. If zero_points is not
+      // constant, fall back to the non-KleidiAI path by leaving scales unpacked.
+      if (has_zp_input_) {
+        const Tensor* zp_tensor = nullptr;
+        OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
+        if (zp_tensor == nullptr) {
+          // zero_points is dynamic: do not mark scales as packed so that the
+          // execution falls back to the non-KleidiAI path.
+          return Status::OK();
+        }
+      }
+
       scales_are_packed_ = true;
       is_packed = true;
+
+      // For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales are still accessible.
+      // After this PrePack returns is_packed=true, ORT may erase scales from the constant
+      // input table (use count drops to 0), making them unavailable in later PrePack calls.
+      // Zero points haven't been PrePacked yet so they are still accessible.
+      if (has_zp_input_ && nbits_ == 4) {
+        const Tensor* zp_tensor = nullptr;
+        OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
+        if (zp_tensor != nullptr) {
+          auto sptr = tensor.Data<float>();
+          auto zptr = zp_tensor->Data<uint8_t>();
+          MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
+                                      has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_);
+        }
+      }
     }
 #endif  // MLAS_TARGET_ARM64
   } else if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 8) {
@@ -404,6 +432,22 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou
     packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
     MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scales_fp32_.get(),
                                 has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_);
+
+#if defined(MLAS_TARGET_ARM64)
+    // For KleidiAI asymmetric 4-bit path: compute BZpCorr during B packing.
+    // The fp16 specialization packs B here (with scales already converted to fp32),
+    // so we also compute BZpCorr now while both scales and zero_points are accessible.
+    if (has_zp_input_ && nbits_ == 4 && scales_fp32_ != nullptr) {
+      const Tensor* zp_tensor = nullptr;
+      OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
+      if (zp_tensor != nullptr) {
+        auto zptr = zp_tensor->Data<uint8_t>();
+        MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(),
+                                    scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_);
+      }
+    }
+#endif  // MLAS_TARGET_ARM64
+
     is_packed = true;
   } else if (compute_type_ == SQNBIT_CompInt8) {
     bool should_pack_scale_and_zp = [&]() {
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h
index 5e610aae86872..f1e9eda62e4f5 100644
--- a/onnxruntime/core/mlas/inc/mlas_qnbit.h
+++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -65,6 +65,9 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS {
     ///< optional post processing to apply to result matrix
     MLAS_GEMM_POSTPROCESSOR<T>* PostProcessor = nullptr;
+
+    const float* BZpCorr = nullptr;       ///< optional: BZpCorrection for KleidiAI asymmetric path (N * BlockCountK floats)
+    const float* AFloatBlkSum = nullptr;  ///< optional: float-domain A block sums for KleidiAI asymmetric path (M * BlockCountK floats)
 };

 /**
diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp
index 14aa6d18d7c9f..e861a26f188ba 100644
--- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -943,6 +943,34 @@ SQ4BitGemm_CompInt8(
     SQ4BitGemm(BlkLen, QuantA, DataParams->PackedQuantBData, DataParams->C, RangeStartM, RangeCountM,
                RangeStartN, RangeCountN, K, DataParams->ldc, DataParams->Bias);
+
+    // Apply zero-point correction for asymmetric quantization (KleidiAI path only).
+    // BZpCorr and AFloatBlkSum are only set when KleidiAI is active with asymmetric
+    // quantization (has zero points). On all other paths they remain nullptr.
+    // C += AFloatBlkSum * BZpCorr^T (for this tile's M/N ranges)
+    if (DataParams->BZpCorr != nullptr && DataParams->AFloatBlkSum != nullptr) {
+        const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
+        const size_t ldc = DataParams->ldc;
+        const float* ABlkSum = DataParams->AFloatBlkSum + RangeStartM * BlockCountK;
+        const float* BCorr = DataParams->BZpCorr + RangeStartN * BlockCountK;
+        float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
+
+        const auto ApplyCorrection = GetMlasPlatform().QNBitGemmDispatch->ApplyBZpCorrection;
+        if (ApplyCorrection) {
+            ApplyCorrection(ABlkSum, BCorr, C, RangeCountM, RangeCountN, BlockCountK, ldc);
+        } else {
+            // Scalar fallback
+            for (size_t m = 0; m < RangeCountM; ++m) {
+                for (size_t n = 0; n < RangeCountN; ++n) {
+                    float corr = 0.0f;
+                    for (size_t blk = 0; blk < BlockCountK; ++blk) {
+                        corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk];
+                    }
+                    C[m * ldc + n] += corr;
+                }
+            }
+        }
+    }
     return;
 }
@@ -1186,6 +1214,7 @@ InitializeWorkspace_CompInt8(
     const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8;
     const auto QuantizeA_Packed = GetMlasPlatform().QNBitGemmDispatch->QuantizeA_Packed_CompInt8;
+    const auto ComputeAFloatBlkSumFn = GetMlasPlatform().QNBitGemmDispatch->ComputeAFloatBlkSum;
     const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8;
     const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8;
@@ -1194,12 +1223,28 @@
     // TODO: try parallel on BatchN * M threads because BatchN is usually 1.
     if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed &&
         UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) {
+        // Compute KleidiAI packed A size (same as workspace size without zero points)
+        const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize
+            ? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize(
+                  M, N, K, BlkLen, /*HasZeroPoint=*/false, SQNBIT_CompInt8, BlkBitWidth, BackendKernelSelectorConfig)
+            : 0;
+
         MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
             const auto& data = DataParams[gemm_idx];
             const float* ARowPtr = data.A;
             std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
             QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr, BackendKernelSelectorConfig);
+
+            // For asymmetric KleidiAI path, also compute float-domain A block sums
+            // for zero-point correction. AFloatBlkSum is stored after KleidiAI packed A.
+            if (data.QuantBZeroPoint != nullptr && ComputeAFloatBlkSumFn != nullptr && kleidiAIPackedASize > 0) {
+                // Align offset so AFloatBlkSum starts at a float-aligned address
+                constexpr size_t FloatAlignment = alignof(float);
+                const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1);
+                float* AFloatBlkSum = reinterpret_cast<float*>(QuantARowPtr + alignedAOffset);
+                ComputeAFloatBlkSumFn(ARowPtr, M, K, BlkLen, data.lda, AFloatBlkSum);
+            }
         });
     } else {
         // TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms
@@ -1482,6 +1527,45 @@ MlasQNBitGemmBatch(
     const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

+    // For KleidiAI asymmetric path: set up BZpCorr and AFloatBlkSum pointers.
+    // BZpCorr is stored after KleidiAI packed B data.
+    // AFloatBlkSum is stored after KleidiAI packed A data in each per-GEMM workspace.
+    const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8;
+    if (Variant == SQ4BitGemmVariant_CompInt8 && has_zp_input && UsePacked &&
+        UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) {
+        // Compute KleidiAI packed B size (without zero point correction space)
+        const size_t kleidiAIPackedBSize = GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize
+            ? GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize(
+                  N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BackendKernelSelectorConfig)
+            : 0;
+        // KleidiAI packed A size (workspace without AFloatBlkSum)
+        const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize
+            ? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize(
+                  M, N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BlkBitWidth, BackendKernelSelectorConfig)
+            : 0;
+
+        // Align offsets so float arrays start at float-aligned addresses
+        constexpr size_t FloatAlignment = alignof(float);
+        const size_t alignedBOffset = (kleidiAIPackedBSize + FloatAlignment - 1) & ~(FloatAlignment - 1);
+        const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1);
+
+        for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
+            auto* Data = const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(&DataParams[gemm_i]);
+            if (Data->QuantBZeroPoint == nullptr) {
+                continue;
+            }
+            // BZpCorr is at the end of packed B data (float-aligned)
+            if (kleidiAIPackedBSize > 0 && Data->PackedQuantBData != nullptr) {
+                Data->BZpCorr = reinterpret_cast<const float*>(Data->PackedQuantBData + alignedBOffset);
+            }
+            // AFloatBlkSum is at the end of the workspace's KleidiAI packed A (float-aligned)
+            if (kleidiAIPackedASize > 0 && Workspace != nullptr) {
+                std::byte* wsBase = reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
+                Data->AFloatBlkSum = reinterpret_cast<const float*>(wsBase + alignedAOffset);
+            }
+        }
+    }
+
     if (ThreadPool == nullptr) {
         for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
             const auto* Data = &DataParams[gemm_i];
diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h
index 9cb7b9aefc295..6503f0108c823 100644
--- a/onnxruntime/core/mlas/lib/qnbitgemm.h
+++ b/onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -529,6 +529,52 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
     QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr;

+    /**
+     * @brief Compute float-domain block sums of A for zero-point correction.
+     *        Used when KleidiAI handles asymmetric quantization.
+     *
+     * @param A        Supplies the float A matrix.
+     * @param CountM   Number of rows of A.
+     * @param CountK   Number of columns of A.
+     * @param BlkLen   Number of values in a block.
+     * @param lda      Leading dimension of A.
+     * @param[out] AFloatBlkSum  Output: M * BlockCountK float sums.
+     */
+    typedef void(ComputeAFloatBlkSum_Fn)(
+        const float* A,
+        size_t CountM,
+        size_t CountK,
+        size_t BlkLen,
+        size_t lda,
+        float* AFloatBlkSum
+    );
+
+    ComputeAFloatBlkSum_Fn* ComputeAFloatBlkSum = nullptr;
+
+    /**
+     * @brief Apply zero-point correction: C += ABlkSum * BCorr^T
+     *        Used after KleidiAI GEMM for asymmetric quantization.
+     *
+     * @param ABlkSum      Float block sums of A, [RangeCountM, BlockCountK] row-major.
+     * @param BCorr        BZpCorrection, [RangeCountN, BlockCountK] row-major (pre-offset).
+     * @param[out] C       Output matrix tile (pre-offset), accumulated.
+     * @param RangeCountM  Number of M rows in this tile.
+     * @param RangeCountN  Number of N columns in this tile.
+     * @param BlockCountK  Number of blocks along K.
+     * @param ldc          Leading dimension of C.
+     */
+    typedef void(ApplyBZpCorrection_Fn)(
+        const float* ABlkSum,
+        const float* BCorr,
+        float* C,
+        size_t RangeCountM,
+        size_t RangeCountN,
+        size_t BlockCountK,
+        size_t ldc
+    );
+
+    ApplyBZpCorrection_Fn* ApplyBZpCorrection = nullptr;
+
     /**
      * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
      *
diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
index 5362b0ba7249d..ac42ced83f36c 100644
--- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
@@ -20,6 +20,7 @@ Module Name:

 #include
 #include
+#include <cstring>
 #include
 #include
@@ -62,13 +63,22 @@ QNBitGemmPackQuantBDataSize(
 #endif

 #ifdef USE_KLEIDIAI
-    if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) {
+    if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig)) {
         const auto& k = GetKleidiAIGemmUKernel();
         const auto& ukernel = k.ukernel;
         const size_t nr = ukernel.get_nr();
         const size_t kr = ukernel.get_kr();
         const size_t sr = ukernel.get_sr();
-        return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16);
+        size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16);
+        if (HasZeroPoint) {
+            // Align so that BZpCorrection starts at a float-aligned offset
+            constexpr size_t FloatAlignment = alignof(float);
+            packed_size = (packed_size + FloatAlignment - 1) & ~(FloatAlignment - 1);
+            // Additional space for BZpCorrection: N * BlockCountK floats
+            const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
+            packed_size += N * BlockCountK * sizeof(float);
+        }
+        return packed_size;
     } else
 #endif
     {
@@ -183,7 +193,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum(
     const std::byte* QuantBDataBegin,
     const float* QuantBScaleBegin,
     bool HasZeroPoint,
-    const std::byte*,
+    const std::byte* QuantBZPBegin,
     PackedQuantBDataStruct<float>& PackedQuantB,
     MLAS_THREADPOOL* ThreadPool,
     const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
@@ -192,11 +202,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum(
 #ifndef USE_KLEIDIAI
     MLAS_UNREFERENCED_PARAMETER(QuantBScaleBegin);
     MLAS_UNREFERENCED_PARAMETER(HasZeroPoint);
+    MLAS_UNREFERENCED_PARAMETER(QuantBZPBegin);
 #endif
     assert(BlkLen >= 16 && BlkLen % 16 == 0);

 #ifdef USE_KLEIDIAI
-    if (UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) {
+    if (UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig)) {
         const auto& k = GetKleidiAIGemmUKernel();
         const auto& ukernel = k.ukernel;
         std::byte* PackedQuantBDataBegin = PackedQuantB.PackedQuantBData;
@@ -204,24 +215,55 @@ SQ4BitGemmPackQuantBDataAndBlkSum(
         const size_t nr = ukernel.get_nr();
         const size_t kr = ukernel.get_kr();
         const size_t sr = ukernel.get_sr();
+        const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

-        kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
-        params.lhs_zero_point = 1;
-        params.rhs_zero_point = 8;
-        params.scale_dt = kai_dt_bf16;
+        // Pack B data with KleidiAI (only when B data is provided)
+        if (QuantBDataBegin != nullptr) {
+            assert(QuantBScaleBegin != nullptr);
+            kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
+            params.lhs_zero_point = 1;
+            params.rhs_zero_point = 8;
+            params.scale_dt = kai_dt_bf16;
+
+            const size_t scales_len = N * BlockCountK;
+            std::vector<uint16_t> scales(scales_len);
+            for (size_t i = 0; i < scales_len; i++) {
+                uint32_t bits;
+                static_assert(sizeof(bits) == sizeof(QuantBScaleBegin[i]), "Unexpected float size");
+                std::memcpy(&bits, &QuantBScaleBegin[i], sizeof(bits));
+                scales[i] = static_cast<uint16_t>(bits >> 16);
+            }

-        const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
-        const size_t scales_len = N * BlockCountK;
-        std::vector<uint16_t> scales(scales_len);
-        for (size_t i = 0; i < scales_len; i++) {
-            const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&QuantBScaleBegin[i]);
-            scales[i] = *i32 >> 16;
+            kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen,
+                                                      reinterpret_cast<const uint8_t*>(QuantBDataBegin), BlockCountK * BlkLen / 2,
+                                                      nullptr, scales.data(), BlockCountK * sizeof(uint16_t),
+                                                      PackedQuantBDataBegin, 0, &params);
         }
-        kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen,
-                                                  reinterpret_cast<const uint8_t*>(QuantBDataBegin), BlockCountK * BlkLen / 2,
-                                                  nullptr, scales.data(), BlockCountK * sizeof(uint16_t),
-                                                  PackedQuantBDataBegin, 0, &params);
+        // Compute BZpCorrection when both scales and zero points are available.
+        // BZpCorr[n * BlockCountK + blk] = scale_b * (8 - zp_b)
+        // Note: We intentionally use fp32 scales here (not bf16-truncated) for higher precision
+        // in the correction term. KleidiAI internally truncates scales to bf16 for the main GEMM,
+        // but the correction benefits from full fp32 precision to better approximate the true result.
+        // This may be called separately from B packing when zero points arrive later.
+        if (HasZeroPoint && QuantBZPBegin != nullptr && QuantBScaleBegin != nullptr) {
+            const size_t kleidiai_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                N, K, nr, kr, sr, BlkLen, kai_dt_bf16);
+            // Align offset so BZpCorr starts at a float-aligned address
+            constexpr size_t FloatAlignment = alignof(float);
+            const size_t bzpcorr_offset = (kleidiai_packed_size + FloatAlignment - 1) & ~(FloatAlignment - 1);
+            float* BZpCorr = reinterpret_cast<float*>(PackedQuantBDataBegin + bzpcorr_offset);
+
+            for (size_t n = 0; n < N; ++n) {
+                for (size_t blk = 0; blk < BlockCountK; ++blk) {
+                    const size_t idx = n * BlockCountK + blk;
+                    const size_t zp_byte_idx = blk / 2;
+                    const uint8_t zp_byte = static_cast<uint8_t>(QuantBZPBegin[n * MlasDivRoundup(BlockCountK, 2) + zp_byte_idx]);
+                    const uint8_t zp = (blk & 1) == 0 ? (zp_byte & 0x0F) : (zp_byte >> 4);
+                    BZpCorr[idx] = QuantBScaleBegin[idx] * (8.0f - static_cast<float>(zp));
+                }
+            }
+        }
     } else
 #endif
     {
@@ -419,14 +461,23 @@ QNBitGemmPerGemmWorkspaceSize(
     case SQNBIT_CompInt8: {
         // workspace buffer is used for block quantization of A to int8
 #ifdef USE_KLEIDIAI
-        if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) {
+        if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig)) {
            const auto& k = (M == 1) ? GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel();
            const auto& ukernel = k.ukernel;
            const size_t mr = ukernel.get_mr();
            const size_t kr = ukernel.get_kr();
            const size_t sr = ukernel.get_sr();
-           return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
+           size_t ws = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
+           if (HasZeroPoint) {
+               // Align so that AFloatBlkSum starts at a float-aligned offset
+               constexpr size_t FloatAlignment = alignof(float);
+               ws = (ws + FloatAlignment - 1) & ~(FloatAlignment - 1);
+               // Additional space for AFloatBlkSum: M * BlockCountK floats
+               const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
+               ws += M * BlockCountK * sizeof(float);
+           }
+           return ws;
        } else
 #endif
        {
@@ -471,7 +522,7 @@ QNBitGemmPerGemmWorkspaceAlignment(
 }  // namespace

 bool
-UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig)
+UseKleidiAI(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig)
 {
 #ifdef USE_KLEIDIAI
     if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) {
@@ -479,12 +530,11 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELEC
     }

     bool has_dotprod = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();
-    return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && !HasZp && has_dotprod;
+    return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && has_dotprod;
 #else
     MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig);
     MLAS_UNREFERENCED_PARAMETER(K);
     MLAS_UNREFERENCED_PARAMETER(BlkLen);
-    MLAS_UNREFERENCED_PARAMETER(HasZp);
     return false;
 #endif
 }
@@ -601,6 +651,8 @@ GetMlasQNBitGemmDispatchNeon(
 #ifdef USE_KLEIDIAI
         d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8;
         d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8;
+        d.ComputeAFloatBlkSum = sqnbitgemm_neon::ComputeAFloatBlkSum;
+        d.ApplyBZpCorrection = sqnbitgemm_neon::ApplyBZpCorrection;
 #endif
     }
diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h
index 61ab7e75395e1..2c31017732a0a 100644
--- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h
+++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h
@@ -220,10 +220,31 @@ SQ4BitGemmKernel_Packed_CompInt8(
     size_t ldc,
     const float *Bias
 );
+
+void
+ComputeAFloatBlkSum(
+    const float* A,
+    size_t CountM,
+    size_t CountK,
+    size_t BlkLen,
+    size_t lda,
+    float* AFloatBlkSum
+);
+
+void
+ApplyBZpCorrection(
+    const float* ABlkSum,
+    const float* BCorr,
+    float* C,
+    size_t RangeCountM,
+    size_t RangeCountN,
+    size_t BlockCountK,
+    size_t ldc
+);
 #endif

 bool
-UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig);
+UseKleidiAI(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig);

 //
 // General helpers.
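
Note (not part of the diff): the correction terms above follow from expanding the asymmetric dequantization. KleidiAI's qsu4 packing assumes a fixed weight zero point of 8, so a block whose true zero point is zp leaves each dequantized weight short by scale * (8 - zp). Summed over a block, the contribution missing from C[m][n] is (sum of the A values covered by that block) * scale * (8 - zp), which is exactly AFloatBlkSum[m][blk] * BZpCorr[n][blk] accumulated over blocks. A minimal scalar sketch of that identity, using illustrative names that do not come from the MLAS sources:

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Sketch only: reference correction for a single output element C[m][n].
// a            : row m of A, length K
// blk_scale    : per-block scales for column n of B, length BlockCountK
// blk_zp       : per-block zero points for column n of B, length BlockCountK
// kleidi_value : GEMM result computed as if every block used zero point 8
float CorrectedOutput(const float* a, const float* blk_scale, const uint8_t* blk_zp,
                      size_t K, size_t BlkLen, float kleidi_value) {
    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
    float corr = 0.0f;
    for (size_t blk = 0; blk < BlockCountK; ++blk) {
        float a_blk_sum = 0.0f;  // plays the role of AFloatBlkSum[m][blk]
        const size_t blk_end = std::min((blk + 1) * BlkLen, K);
        for (size_t k = blk * BlkLen; k < blk_end; ++k) {
            a_blk_sum += a[k];
        }
        // plays the role of BZpCorr[n][blk] = scale * (8 - zp)
        const float b_zp_corr = blk_scale[blk] * (8.0f - static_cast<float>(blk_zp[blk]));
        corr += a_blk_sum * b_zp_corr;
    }
    return kleidi_value + corr;  // corrected C[m][n]
}
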
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
index a688730ebaa33..5b09495436eaf 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
@@ -134,7 +134,9 @@ QuantizeBlock(
 bool
 UsePacked_CompInt8(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig)
 {
-    return UseKleidiAI(K, BlkLen, HasZp, BackendKernelSelectorConfig);
+    MLAS_UNREFERENCED_PARAMETER(HasZp);
+    // Use KleidiAI packed path for both symmetric and asymmetric (with ZP correction).
+    return UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig);
 }

 #ifdef USE_KLEIDIAI
@@ -169,6 +171,120 @@ QuantizeA_Packed_CompInt8(
     kai_run_lhs_quant_pack_qai8dxp_f32(CountM, CountK, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr);
 }
+
+void
+ComputeAFloatBlkSum(
+    const float* A,
+    size_t CountM,
+    size_t CountK,
+    size_t BlkLen,
+    size_t lda,
+    float* AFloatBlkSum
+)
+{
+    const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen);
+    for (size_t m = 0; m < CountM; ++m) {
+        const float* a_row = A + m * lda;
+        float* blk_sum_row = AFloatBlkSum + m * BlockCountK;
+        for (size_t blk = 0; blk < BlockCountK; ++blk) {
+            const size_t blk_start = blk * BlkLen;
+            const size_t blk_end = std::min(blk_start + BlkLen, CountK);
+            float32x4_t s0 = vdupq_n_f32(0.0f);
+            float32x4_t s1 = vdupq_n_f32(0.0f);
+            float32x4_t s2 = vdupq_n_f32(0.0f);
+            float32x4_t s3 = vdupq_n_f32(0.0f);
+            size_t k = blk_start;
+            for (; k + 16 <= blk_end; k += 16) {
+                s0 = vaddq_f32(s0, vld1q_f32(a_row + k));
+                s1 = vaddq_f32(s1, vld1q_f32(a_row + k + 4));
+                s2 = vaddq_f32(s2, vld1q_f32(a_row + k + 8));
+                s3 = vaddq_f32(s3, vld1q_f32(a_row + k + 12));
+            }
+            s0 = vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3));
+            for (; k + 4 <= blk_end; k += 4) {
+                s0 = vaddq_f32(s0, vld1q_f32(a_row + k));
+            }
+            float sum = vaddvq_f32(s0);
+            for (; k < blk_end; ++k) {
+                sum += a_row[k];
+            }
+            blk_sum_row[blk] = sum;
+        }
+    }
+}
+
+void
+ApplyBZpCorrection(
+    const float* ABlkSum,
+    const float* BCorr,
+    float* C,
+    size_t RangeCountM,
+    size_t RangeCountN,
+    size_t BlockCountK,
+    size_t ldc
+)
+{
+    // Process 4 N columns at a time. For each n-tile, iterate all M rows.
+    // This keeps the 4 BCorr rows in L1 cache and reuses them across all M.
+    size_t n = 0;
+    for (; n + 4 <= RangeCountN; n += 4) {
+        const float* b0 = BCorr + (n + 0) * BlockCountK;
+        const float* b1 = BCorr + (n + 1) * BlockCountK;
+        const float* b2 = BCorr + (n + 2) * BlockCountK;
+        const float* b3 = BCorr + (n + 3) * BlockCountK;
+
+        for (size_t m = 0; m < RangeCountM; ++m) {
+            const float* a_row = ABlkSum + m * BlockCountK;
+            float32x4_t acc0 = vdupq_n_f32(0.0f);
+            float32x4_t acc1 = vdupq_n_f32(0.0f);
+            float32x4_t acc2 = vdupq_n_f32(0.0f);
+            float32x4_t acc3 = vdupq_n_f32(0.0f);
+            size_t blk = 0;
+            for (; blk + 4 <= BlockCountK; blk += 4) {
+                float32x4_t a_vec = vld1q_f32(a_row + blk);
+                acc0 = vfmaq_f32(acc0, a_vec, vld1q_f32(b0 + blk));
+                acc1 = vfmaq_f32(acc1, a_vec, vld1q_f32(b1 + blk));
+                acc2 = vfmaq_f32(acc2, a_vec, vld1q_f32(b2 + blk));
+                acc3 = vfmaq_f32(acc3, a_vec, vld1q_f32(b3 + blk));
+            }
+            // Horizontal reduce and add scalar tail
+            float c0 = vaddvq_f32(acc0);
+            float c1 = vaddvq_f32(acc1);
+            float c2 = vaddvq_f32(acc2);
+            float c3 = vaddvq_f32(acc3);
+            for (; blk < BlockCountK; ++blk) {
+                float a_val = a_row[blk];
+                c0 += a_val * b0[blk];
+                c1 += a_val * b1[blk];
+                c2 += a_val * b2[blk];
+                c3 += a_val * b3[blk];
+            }
+            float* c_ptr = C + m * ldc + n;
+            c_ptr[0] += c0;
+            c_ptr[1] += c1;
+            c_ptr[2] += c2;
+            c_ptr[3] += c3;
+        }
+    }
+
+    // Handle remaining N columns
+    for (; n < RangeCountN; ++n) {
+        const float* b_row = BCorr + n * BlockCountK;
+        for (size_t m = 0; m < RangeCountM; ++m) {
+            const float* a_row = ABlkSum + m * BlockCountK;
+            float32x4_t sum_vec = vdupq_n_f32(0.0f);
+            size_t blk = 0;
+            for (; blk + 4 <= BlockCountK; blk += 4) {
+                sum_vec = vfmaq_f32(sum_vec, vld1q_f32(a_row + blk), vld1q_f32(b_row + blk));
+            }
+            float corr = vaddvq_f32(sum_vec);
+            for (; blk < BlockCountK; ++blk) {
+                corr += a_row[blk] * b_row[blk];
+            }
+            C[m * ldc + n] += corr;
+        }
+    }
+}
 #endif

 void