From 5adbcb3ee9d33765b224a847ae2c0d7db9f33348 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 17 Mar 2026 23:22:48 +0000 Subject: [PATCH 01/10] init --- .../cpu/quantization/matmul_nbits.cc | 15 ++++ onnxruntime/core/mlas/inc/mlas_qnbit.h | 3 + onnxruntime/core/mlas/lib/qnbitgemm.cpp | 68 +++++++++++++++ onnxruntime/core/mlas/lib/qnbitgemm.h | 22 +++++ .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 87 ++++++++++++++----- .../core/mlas/lib/qnbitgemm_kernel_neon.h | 13 +++ .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 35 +++++++- 7 files changed, 219 insertions(+), 24 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 02d42fbabb1d3..febd698b587da 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -308,6 +308,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All scales_are_packed_ = true; is_packed = true; } + + // For KleidiAI asymmetric path: compute BZpCorr when zero points become available. + // BZpCorr = scale_b * (8 - zp_b) for each block, stored after KleidiAI packed B data. + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && has_zp_input_ && nbits_ == 4) { + auto zptr = tensor.Data(); + const Tensor* scales_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales_tensor); + if (scales_tensor != nullptr) { + auto sptr = scales_tensor->Data(); + // Re-invoke packing with zero points to compute BZpCorr + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + } + is_packed = false; + } #endif // MLAS_TARGET_ARM64 } else if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 8) { // For 8-bit HQNBIT_CompInt8, scales are fp16 but the SQ8 packing functions expect float. 
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 5e610aae86872..f1e9eda62e4f5 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -65,6 +65,9 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS { ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + + const float* BZpCorr = nullptr; ///< optional: BZpCorrection for KleidiAI asymmetric path (N * BlockCountK floats) + const float* AFloatBlkSum = nullptr; ///< optional: float-domain A block sums for KleidiAI asymmetric path (M * BlockCountK floats) }; /** diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 14aa6d18d7c9f..7fb8853b96f25 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -943,6 +943,29 @@ SQ4BitGemm_CompInt8( SQ4BitGemm(BlkLen, QuantA, DataParams->PackedQuantBData, DataParams->C, RangeStartM, RangeCountM, RangeStartN, RangeCountN, K, DataParams->ldc, DataParams->Bias); + + // Apply zero-point correction for asymmetric quantization: + // C += AFloatBlkSum * BZpCorr^T (for this tile's M/N ranges) + if (DataParams->BZpCorr != nullptr && DataParams->AFloatBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t ldc = DataParams->ldc; + const float* ABlkSum = DataParams->AFloatBlkSum + RangeStartM * BlockCountK; + const float* BCorr = DataParams->BZpCorr + RangeStartN * BlockCountK; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + // Correction GEMM: C[m,n] += sum_blk(ABlkSum[m,blk] * BCorr[n,blk]) + // ABlkSum is (M x BlockCountK) row-major, BCorr is (N x BlockCountK) row-major + // This is C += ABlkSum * BCorr^T + for (size_t m = 0; m < RangeCountM; ++m) { + for (size_t n = 0; n < RangeCountN; ++n) { + float corr = 0.0f; + for (size_t blk = 0; blk < BlockCountK; ++blk) { + corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * 
BlockCountK + blk]; + } + C[m * ldc + n] += corr; + } + } + } return; } @@ -1186,6 +1209,7 @@ InitializeWorkspace_CompInt8( const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; const auto QuantizeA_Packed = GetMlasPlatform().QNBitGemmDispatch->QuantizeA_Packed_CompInt8; + const auto ComputeAFloatBlkSumFn = GetMlasPlatform().QNBitGemmDispatch->ComputeAFloatBlkSum; const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; @@ -1194,12 +1218,25 @@ InitializeWorkspace_CompInt8( // TODO: try parallel on BatchN * M threads because BatchN is usually 1. if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) { + // Compute KleidiAI packed A size (same as workspace size without zero points) + const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize + ? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize( + M, N, K, BlkLen, /*HasZeroPoint=*/false, SQNBIT_CompInt8, BlkBitWidth, BackendKernelSelectorConfig) + : 0; + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; const float* ARowPtr = data.A; std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr, BackendKernelSelectorConfig); + + // For asymmetric KleidiAI path, also compute float-domain A block sums + // for zero-point correction. AFloatBlkSum is stored after KleidiAI packed A. 
+ if (data.QuantBZeroPoint != nullptr && ComputeAFloatBlkSumFn != nullptr && kleidiAIPackedASize > 0) { + float* AFloatBlkSum = reinterpret_cast(QuantARowPtr + kleidiAIPackedASize); + ComputeAFloatBlkSumFn(ARowPtr, M, K, BlkLen, data.lda, AFloatBlkSum); + } }); } else { // TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms @@ -1482,6 +1519,37 @@ MlasQNBitGemmBatch( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + // For KleidiAI asymmetric path: set up BZpCorr and AFloatBlkSum pointers. + // BZpCorr is stored after KleidiAI packed B data. + // AFloatBlkSum is stored after KleidiAI packed A data in each per-GEMM workspace. + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + if (Variant == SQ4BitGemmVariant_CompInt8 && has_zp_input && UsePacked && + UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) { + // Compute KleidiAI packed B size (without zero point correction space) + const size_t kleidiAIPackedBSize = GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize + ? GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize( + N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BackendKernelSelectorConfig) + : 0; + // KleidiAI packed A size (workspace without AFloatBlkSum) + const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize + ? 
GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize( + M, N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BlkBitWidth, BackendKernelSelectorConfig) + : 0; + + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto* Data = const_cast*>(&DataParams[gemm_i]); + // BZpCorr is at the end of packed B data + if (kleidiAIPackedBSize > 0 && Data->PackedQuantBData != nullptr) { + Data->BZpCorr = reinterpret_cast(Data->PackedQuantBData + kleidiAIPackedBSize); + } + // AFloatBlkSum is at the end of the workspace's KleidiAI packed A + if (kleidiAIPackedASize > 0 && Workspace != nullptr) { + std::byte* wsBase = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + Data->AFloatBlkSum = reinterpret_cast(wsBase + kleidiAIPackedASize); + } + } + } + if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 9cb7b9aefc295..a09a4463ad81e 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -529,6 +529,28 @@ struct MLAS_QNBIT_GEMM_DISPATCH { QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr; + /** + * @brief Compute float-domain block sums of A for zero-point correction. + * Used when KleidiAI handles asymmetric quantization. + * + * @param A Supplies the float A matrix. + * @param CountM Number of rows of A. + * @param CountK Number of columns of A. + * @param BlkLen Number of values in a block. + * @param lda Leading dimension of A. + * @param[out] AFloatBlkSum Output: M * BlockCountK float sums. + */ + typedef void(ComputeAFloatBlkSum_Fn)( + const float* A, + size_t CountM, + size_t CountK, + size_t BlkLen, + size_t lda, + float* AFloatBlkSum + ); + + ComputeAFloatBlkSum_Fn* ComputeAFloatBlkSum = nullptr; + /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. 
* diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 5362b0ba7249d..e1c42b7a0c994 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -62,13 +62,19 @@ QNBitGemmPackQuantBDataSize( #endif #ifdef USE_KLEIDIAI - if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) { + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) { const auto& k = GetKleidiAIGemmUKernel(); const auto& ukernel = k.ukernel; const size_t nr = ukernel.get_nr(); const size_t kr = ukernel.get_kr(); const size_t sr = ukernel.get_sr(); - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + if (HasZeroPoint) { + // Additional space for BZpCorrection: N * BlockCountK floats + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + packed_size += N * BlockCountK * sizeof(float); + } + return packed_size; } else #endif { @@ -183,7 +189,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool HasZeroPoint, - const std::byte*, + const std::byte* QuantBZPBegin, PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig @@ -192,11 +198,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum( #ifndef USE_KLEIDIAI MLAS_UNREFERENCED_PARAMETER(QuantBScaleBegin); MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(QuantBZPBegin); #endif assert(BlkLen >= 16 && BlkLen % 16 == 0); #ifdef USE_KLEIDIAI - if (UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) { + if (UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) { const auto& k = 
GetKleidiAIGemmUKernel(); const auto& ukernel = k.ukernel; std::byte* PackedQuantBDataBegin = PackedQuantB.PackedQuantBData; @@ -204,24 +211,46 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t nr = ukernel.get_nr(); const size_t kr = ukernel.get_kr(); const size_t sr = ukernel.get_sr(); + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = kai_dt_bf16; + // Pack B data with KleidiAI (only when B data is provided) + if (QuantBDataBegin != nullptr && QuantBScaleBegin != nullptr) { + kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = kai_dt_bf16; + + const size_t scales_len = N * BlockCountK; + std::vector scales(scales_len); + for (size_t i = 0; i < scales_len; i++) { + const uint32_t* i32 = reinterpret_cast(&QuantBScaleBegin[i]); + scales[i] = *i32 >> 16; + } - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t scales_len = N * BlockCountK; - std::vector scales(scales_len); - for (size_t i = 0; i < scales_len; i++) { - const uint32_t* i32 = reinterpret_cast(&QuantBScaleBegin[i]); - scales[i] = *i32 >> 16; + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen, + reinterpret_cast(QuantBDataBegin), BlockCountK * BlkLen / 2, + nullptr, scales.data(), BlockCountK * sizeof(uint16_t), + PackedQuantBDataBegin, 0, ¶ms); } - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen, - reinterpret_cast(QuantBDataBegin), BlockCountK * BlkLen / 2, - nullptr, scales.data(), BlockCountK * sizeof(uint16_t), - PackedQuantBDataBegin, 0, ¶ms); + // Compute BZpCorrection when both scales and zero points are available. + // BZpCorr[n * BlockCountK + blk] = scale_b * (8 - zp_b) + // This may be called separately from B packing when zero points arrive later. 
+ if (HasZeroPoint && QuantBZPBegin != nullptr && QuantBScaleBegin != nullptr) { + const size_t kleidiai_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + float* BZpCorr = reinterpret_cast(PackedQuantBDataBegin + kleidiai_packed_size); + + for (size_t n = 0; n < N; ++n) { + for (size_t blk = 0; blk < BlockCountK; ++blk) { + const size_t idx = n * BlockCountK + blk; + const size_t zp_byte_idx = blk / 2; + const uint8_t zp_byte = static_cast(QuantBZPBegin[n * MlasDivRoundup(BlockCountK, 2) + zp_byte_idx]); + const uint8_t zp = (blk & 1) == 0 ? (zp_byte & 0x0F) : (zp_byte >> 4); + BZpCorr[idx] = QuantBScaleBegin[idx] * (8.0f - static_cast(zp)); + } + } + } } else #endif { @@ -419,14 +448,20 @@ QNBitGemmPerGemmWorkspaceSize( case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 #ifdef USE_KLEIDIAI - if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) { + if (BlkBitWidth == 4 && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) { const auto& k = (M == 1) ? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); const auto& ukernel = k.ukernel; const size_t mr = ukernel.get_mr(); const size_t kr = ukernel.get_kr(); const size_t sr = ukernel.get_sr(); - return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + size_t ws = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + if (HasZeroPoint) { + // Additional space for AFloatBlkSum: M * BlockCountK floats + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + ws += M * BlockCountK * sizeof(float); + } + return ws; } else #endif { @@ -471,7 +506,7 @@ QNBitGemmPerGemmWorkspaceAlignment( } // namespace bool -UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) +UseKleidiAIBase(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) { #ifdef USE_KLEIDIAI if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) { @@ -479,16 +514,21 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELEC } bool has_dotprod = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); - return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && !HasZp && has_dotprod; + return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && has_dotprod; #else MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); MLAS_UNREFERENCED_PARAMETER(K); MLAS_UNREFERENCED_PARAMETER(BlkLen); - MLAS_UNREFERENCED_PARAMETER(HasZp); return false; #endif } +bool +UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) +{ + return !HasZp && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig); +} + template size_t SQ8BitGemmKernel_BlkSum_CompInt8( @@ -601,6 +641,7 @@ GetMlasQNBitGemmDispatchNeon( #ifdef USE_KLEIDIAI d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; + 
d.ComputeAFloatBlkSum = sqnbitgemm_neon::ComputeAFloatBlkSum; #endif } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index 61ab7e75395e1..e7074f6fd08b3 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -220,8 +220,21 @@ SQ4BitGemmKernel_Packed_CompInt8( size_t ldc, const float *Bias ); + +void +ComputeAFloatBlkSum( + const float* A, + size_t CountM, + size_t CountK, + size_t BlkLen, + size_t lda, + float* AFloatBlkSum +); #endif +bool +UseKleidiAIBase(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + bool UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index a688730ebaa33..4c682e27da87d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -134,7 +134,9 @@ QuantizeBlock( bool UsePacked_CompInt8(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) { - return UseKleidiAI(K, BlkLen, HasZp, BackendKernelSelectorConfig); + MLAS_UNREFERENCED_PARAMETER(HasZp); + // Use KleidiAI packed path for both symmetric and asymmetric (with ZP correction). 
+ return UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig); } #ifdef USE_KLEIDIAI @@ -169,6 +171,37 @@ QuantizeA_Packed_CompInt8( kai_run_lhs_quant_pack_qai8dxp_f32(CountM, CountK, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr); } + +void +ComputeAFloatBlkSum( + const float* A, + size_t CountM, + size_t CountK, + size_t BlkLen, + size_t lda, + float* AFloatBlkSum +) +{ + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + for (size_t m = 0; m < CountM; ++m) { + const float* a_row = A + m * lda; + float* blk_sum_row = AFloatBlkSum + m * BlockCountK; + for (size_t blk = 0; blk < BlockCountK; ++blk) { + const size_t blk_start = blk * BlkLen; + const size_t blk_end = std::min(blk_start + BlkLen, CountK); + float32x4_t sum_vec = vdupq_n_f32(0.0f); + size_t k = blk_start; + for (; k + 4 <= blk_end; k += 4) { + sum_vec = vaddq_f32(sum_vec, vld1q_f32(a_row + k)); + } + float sum = vaddvq_f32(sum_vec); + for (; k < blk_end; ++k) { + sum += a_row[k]; + } + blk_sum_row[blk] = sum; + } + } +} #endif void From 35e6c85cfc6f1a4e843ab6f297d8d52400280b58 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 17 Mar 2026 23:28:42 +0000 Subject: [PATCH 02/10] optimized correction --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 32 ++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 7fb8853b96f25..0d2ef806f3efe 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -953,18 +953,28 @@ SQ4BitGemm_CompInt8( const float* BCorr = DataParams->BZpCorr + RangeStartN * BlockCountK; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - // Correction GEMM: C[m,n] += sum_blk(ABlkSum[m,blk] * BCorr[n,blk]) - // ABlkSum is (M x BlockCountK) row-major, BCorr is (N x BlockCountK) row-major - // This is C += ABlkSum * BCorr^T - for (size_t m = 0; m < RangeCountM; ++m) { - for (size_t n = 0; n < RangeCountN; ++n) { 
- float corr = 0.0f; - for (size_t blk = 0; blk < BlockCountK; ++blk) { - corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk]; - } - C[m * ldc + n] += corr; - } + // Correction GEMM: C += ABlkSum * BCorr^T using platform-optimized SGEMM +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64) + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + ABlkSum, BCorr, C, BlockCountK, RowsRemaining, RangeCountN, BlockCountK, ldc, 1.f, false + ); + C += ldc * RowsHandled; + ABlkSum += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; } +#else + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd( + ABlkSum, BCorr, C, BlockCountK, RowsRemaining, RangeCountN, BlockCountK, ldc, 1.f + ); + C += ldc * RowsHandled; + ABlkSum += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } +#endif } return; } From f602a51d1a66cf8e3d04ba1209bb70ce6a3b79bf Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 18 Mar 2026 06:33:05 +0000 Subject: [PATCH 03/10] use naive loop --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 33 +++++++++---------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 0d2ef806f3efe..29638ff117c22 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -953,28 +953,19 @@ SQ4BitGemm_CompInt8( const float* BCorr = DataParams->BZpCorr + RangeStartN * BlockCountK; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - // Correction GEMM: C += ABlkSum * BCorr^T using platform-optimized SGEMM -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64) - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - 
auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - ABlkSum, BCorr, C, BlockCountK, RowsRemaining, RangeCountN, BlockCountK, ldc, 1.f, false - ); - C += ldc * RowsHandled; - ABlkSum += BlockCountK * RowsHandled; - RowsRemaining -= RowsHandled; - } -#else - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - auto RowsHandled = MlasSgemmKernelAdd( - ABlkSum, BCorr, C, BlockCountK, RowsRemaining, RangeCountN, BlockCountK, ldc, 1.f - ); - C += ldc * RowsHandled; - ABlkSum += BlockCountK * RowsHandled; - RowsRemaining -= RowsHandled; + // Correction GEMM: C[m,n] += sum_blk(ABlkSum[m,blk] * BCorr[n,blk]) + // ABlkSum is (M x BlockCountK) row-major, BCorr is (N x BlockCountK) row-major + // Note: cannot use MlasSgemmKernelAdd here because it expects B in panel-packed + // format, but BZpCorr is in simple [N, BlockCountK] row-major layout. + for (size_t m = 0; m < RangeCountM; ++m) { + for (size_t n = 0; n < RangeCountN; ++n) { + float corr = 0.0f; + for (size_t blk = 0; blk < BlockCountK; ++blk) { + corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk]; + } + C[m * ldc + n] += corr; + } } -#endif } return; } From afc0ae2097723a34b6a4535cb32b399b681c2b37 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 18 Mar 2026 06:54:42 +0000 Subject: [PATCH 04/10] neon correction --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 22 +++++++------- onnxruntime/core/mlas/lib/qnbitgemm.h | 24 +++++++++++++++ .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 1 + .../core/mlas/lib/qnbitgemm_kernel_neon.h | 11 +++++++ .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 30 +++++++++++++++++++ 5 files changed, 78 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 29638ff117c22..8468f5aa048f5 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -953,17 +953,19 @@ SQ4BitGemm_CompInt8( const float* BCorr = DataParams->BZpCorr + RangeStartN 
* BlockCountK; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - // Correction GEMM: C[m,n] += sum_blk(ABlkSum[m,blk] * BCorr[n,blk]) - // ABlkSum is (M x BlockCountK) row-major, BCorr is (N x BlockCountK) row-major - // Note: cannot use MlasSgemmKernelAdd here because it expects B in panel-packed - // format, but BZpCorr is in simple [N, BlockCountK] row-major layout. - for (size_t m = 0; m < RangeCountM; ++m) { - for (size_t n = 0; n < RangeCountN; ++n) { - float corr = 0.0f; - for (size_t blk = 0; blk < BlockCountK; ++blk) { - corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk]; + const auto ApplyCorrection = GetMlasPlatform().QNBitGemmDispatch->ApplyBZpCorrection; + if (ApplyCorrection) { + ApplyCorrection(ABlkSum, BCorr, C, RangeCountM, RangeCountN, BlockCountK, ldc); + } else { + // Scalar fallback + for (size_t m = 0; m < RangeCountM; ++m) { + for (size_t n = 0; n < RangeCountN; ++n) { + float corr = 0.0f; + for (size_t blk = 0; blk < BlockCountK; ++blk) { + corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk]; + } + C[m * ldc + n] += corr; } - C[m * ldc + n] += corr; } } } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index a09a4463ad81e..6503f0108c823 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -551,6 +551,30 @@ struct MLAS_QNBIT_GEMM_DISPATCH { ComputeAFloatBlkSum_Fn* ComputeAFloatBlkSum = nullptr; + /** + * @brief Apply zero-point correction: C += ABlkSum * BCorr^T + * Used after KleidiAI GEMM for asymmetric quantization. + * + * @param ABlkSum Float block sums of A, [RangeCountM, BlockCountK] row-major. + * @param BCorr BZpCorrection, [RangeCountN, BlockCountK] row-major (pre-offset). + * @param[out] C Output matrix tile (pre-offset), accumulated. + * @param RangeCountM Number of M rows in this tile. + * @param RangeCountN Number of N columns in this tile. + * @param BlockCountK Number of blocks along K. 
+ * @param ldc Leading dimension of C. + */ + typedef void(ApplyBZpCorrection_Fn)( + const float* ABlkSum, + const float* BCorr, + float* C, + size_t RangeCountM, + size_t RangeCountN, + size_t BlockCountK, + size_t ldc + ); + + ApplyBZpCorrection_Fn* ApplyBZpCorrection = nullptr; + /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. * diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index e1c42b7a0c994..9037fc18258e9 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -642,6 +642,7 @@ GetMlasQNBitGemmDispatchNeon( d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; d.ComputeAFloatBlkSum = sqnbitgemm_neon::ComputeAFloatBlkSum; + d.ApplyBZpCorrection = sqnbitgemm_neon::ApplyBZpCorrection; #endif } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index e7074f6fd08b3..3451b659dfb84 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -230,6 +230,17 @@ ComputeAFloatBlkSum( size_t lda, float* AFloatBlkSum ); + +void +ApplyBZpCorrection( + const float* ABlkSum, + const float* BCorr, + float* C, + size_t RangeCountM, + size_t RangeCountN, + size_t BlockCountK, + size_t ldc +); #endif bool diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 4c682e27da87d..f85b21b9ce1e1 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -202,6 +202,36 @@ ComputeAFloatBlkSum( } } } + +void +ApplyBZpCorrection( + const float* ABlkSum, + const float* BCorr, + float* C, + size_t RangeCountM, + 
size_t RangeCountN, + size_t BlockCountK, + size_t ldc +) +{ + for (size_t m = 0; m < RangeCountM; ++m) { + const float* a_row = ABlkSum + m * BlockCountK; + float* c_row = C + m * ldc; + for (size_t n = 0; n < RangeCountN; ++n) { + const float* b_row = BCorr + n * BlockCountK; + float32x4_t sum_vec = vdupq_n_f32(0.0f); + size_t blk = 0; + for (; blk + 4 <= BlockCountK; blk += 4) { + sum_vec = vfmaq_f32(sum_vec, vld1q_f32(a_row + blk), vld1q_f32(b_row + blk)); + } + float corr = vaddvq_f32(sum_vec); + for (; blk < BlockCountK; ++blk) { + corr += a_row[blk] * b_row[blk]; + } + c_row[n] += corr; + } + } +} #endif void From 2d74cd754558f8541bd38a641d8f7f97046e0e71 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 18 Mar 2026 17:39:35 +0000 Subject: [PATCH 05/10] fix correction --- .../cpu/quantization/matmul_nbits.cc | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index febd698b587da..6093a8d8df16e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -307,21 +307,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All has_zp_input_, &mlas_backend_kernel_selector_config_)) { scales_are_packed_ = true; is_packed = true; - } - // For KleidiAI asymmetric path: compute BZpCorr when zero points become available. - // BZpCorr = scale_b * (8 - zp_b) for each block, stored after KleidiAI packed B data. 
- if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && has_zp_input_ && nbits_ == 4) { - auto zptr = tensor.Data(); - const Tensor* scales_tensor = nullptr; - OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales_tensor); - if (scales_tensor != nullptr) { - auto sptr = scales_tensor->Data(); - // Re-invoke packing with zero points to compute BZpCorr - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + // For KleidiAI asymmetric path: compute BZpCorr now while scales are still accessible. + // After this PrePack returns is_packed=true, ORT may erase scales from the constant + // input table (use count drops to 0), making them unavailable in later PrePack calls. + // Zero points haven't been PrePacked yet so they are still accessible. + if (has_zp_input_ && nbits_ == 4) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor != nullptr) { + auto sptr = tensor.Data(); + auto zptr = zp_tensor->Data(); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + } } - is_packed = false; } #endif // MLAS_TARGET_ARM64 } else if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 8) { From 30fe481bb566929f1d98bc8a90b6dd3760208fc5 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 18 Mar 2026 19:41:09 +0000 Subject: [PATCH 06/10] more opt --- .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 71 ++++++++++++++++--- 1 file changed, 62 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index f85b21b9ce1e1..d1bdfce11dc71 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -189,12 +189,22 @@ ComputeAFloatBlkSum( for (size_t blk = 0; blk < BlockCountK; ++blk) { const size_t blk_start = blk * BlkLen; const size_t blk_end = std::min(blk_start + BlkLen, CountK); - float32x4_t sum_vec = vdupq_n_f32(0.0f); + float32x4_t s0 = vdupq_n_f32(0.0f); + float32x4_t s1 = vdupq_n_f32(0.0f); + float32x4_t s2 = vdupq_n_f32(0.0f); + float32x4_t s3 = vdupq_n_f32(0.0f); size_t k = blk_start; + for (; k + 16 <= blk_end; k += 16) { + s0 = vaddq_f32(s0, vld1q_f32(a_row + k)); + s1 = vaddq_f32(s1, vld1q_f32(a_row + k + 4)); + s2 = vaddq_f32(s2, vld1q_f32(a_row + k + 8)); + s3 = vaddq_f32(s3, vld1q_f32(a_row + k + 12)); + } + s0 = vaddq_f32(vaddq_f32(s0, s1), vaddq_f32(s2, s3)); for (; k + 4 <= blk_end; k += 4) { - sum_vec = vaddq_f32(sum_vec, vld1q_f32(a_row + k)); + s0 = vaddq_f32(s0, vld1q_f32(a_row + k)); } - float sum = vaddvq_f32(sum_vec); + float sum = vaddvq_f32(s0); for (; k < blk_end; ++k) { sum += a_row[k]; } @@ -214,11 +224,54 @@ ApplyBZpCorrection( size_t ldc ) { - for (size_t m = 0; m < RangeCountM; ++m) { - const float* a_row = ABlkSum + m * BlockCountK; - float* c_row = C + m * ldc; - for (size_t n = 0; n < RangeCountN; ++n) { - const float* b_row = BCorr + n * BlockCountK; + // Process 4 N columns at a time. For each n-tile, iterate all M rows. + // This keeps the 4 BCorr rows in L1 cache and reuses them across all M. 
+ size_t n = 0; + for (; n + 4 <= RangeCountN; n += 4) { + const float* b0 = BCorr + (n + 0) * BlockCountK; + const float* b1 = BCorr + (n + 1) * BlockCountK; + const float* b2 = BCorr + (n + 2) * BlockCountK; + const float* b3 = BCorr + (n + 3) * BlockCountK; + + for (size_t m = 0; m < RangeCountM; ++m) { + const float* a_row = ABlkSum + m * BlockCountK; + float32x4_t acc0 = vdupq_n_f32(0.0f); + float32x4_t acc1 = vdupq_n_f32(0.0f); + float32x4_t acc2 = vdupq_n_f32(0.0f); + float32x4_t acc3 = vdupq_n_f32(0.0f); + size_t blk = 0; + for (; blk + 4 <= BlockCountK; blk += 4) { + float32x4_t a_vec = vld1q_f32(a_row + blk); + acc0 = vfmaq_f32(acc0, a_vec, vld1q_f32(b0 + blk)); + acc1 = vfmaq_f32(acc1, a_vec, vld1q_f32(b1 + blk)); + acc2 = vfmaq_f32(acc2, a_vec, vld1q_f32(b2 + blk)); + acc3 = vfmaq_f32(acc3, a_vec, vld1q_f32(b3 + blk)); + } + // Horizontal reduce and add scalar tail + float c0 = vaddvq_f32(acc0); + float c1 = vaddvq_f32(acc1); + float c2 = vaddvq_f32(acc2); + float c3 = vaddvq_f32(acc3); + for (; blk < BlockCountK; ++blk) { + float a_val = a_row[blk]; + c0 += a_val * b0[blk]; + c1 += a_val * b1[blk]; + c2 += a_val * b2[blk]; + c3 += a_val * b3[blk]; + } + float* c_ptr = C + m * ldc + n; + c_ptr[0] += c0; + c_ptr[1] += c1; + c_ptr[2] += c2; + c_ptr[3] += c3; + } + } + + // Handle remaining N columns + for (; n < RangeCountN; ++n) { + const float* b_row = BCorr + n * BlockCountK; + for (size_t m = 0; m < RangeCountM; ++m) { + const float* a_row = ABlkSum + m * BlockCountK; float32x4_t sum_vec = vdupq_n_f32(0.0f); size_t blk = 0; for (; blk + 4 <= BlockCountK; blk += 4) { @@ -228,7 +281,7 @@ ApplyBZpCorrection( for (; blk < BlockCountK; ++blk) { corr += a_row[blk] * b_row[blk]; } - c_row[n] += corr; + C[m * ldc + n] += corr; } } } From f417a99bf8b717224f99efabc33e60e8fa2b05da Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 18 Mar 2026 23:43:08 +0000 Subject: [PATCH 07/10] Fix float alignment and per-entry zero point gating - Align BZpCorr and 
AFloatBlkSum offsets to alignof(float) to avoid undefined behavior from misaligned float pointer dereferences. KleidiAI packed sizes may not be multiples of 4. - Gate BZpCorr/AFloatBlkSum pointer setup on per-entry Data->QuantBZeroPoint != nullptr (skip entries without zero points). - Assert QuantBScaleBegin != nullptr when QuantBDataBegin is provided in KleidiAI packing path (scales are required for RHS packing). --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 21 ++++++++++++++----- .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 14 +++++++++++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 8468f5aa048f5..7ab94e8d800d2 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -1237,7 +1237,10 @@ InitializeWorkspace_CompInt8( // For asymmetric KleidiAI path, also compute float-domain A block sums // for zero-point correction. AFloatBlkSum is stored after KleidiAI packed A. 
if (data.QuantBZeroPoint != nullptr && ComputeAFloatBlkSumFn != nullptr && kleidiAIPackedASize > 0) { - float* AFloatBlkSum = reinterpret_cast<float*>(QuantARowPtr + kleidiAIPackedASize); + // Align offset so AFloatBlkSum starts at a float-aligned address + constexpr size_t FloatAlignment = alignof(float); + const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1); + float* AFloatBlkSum = reinterpret_cast<float*>(QuantARowPtr + alignedAOffset); ComputeAFloatBlkSumFn(ARowPtr, M, K, BlkLen, data.lda, AFloatBlkSum); } }); @@ -1539,16 +1542,24 @@ MlasQNBitGemmBatch( M, N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BlkBitWidth, BackendKernelSelectorConfig) : 0; + // Align offsets so float arrays start at float-aligned addresses + constexpr size_t FloatAlignment = alignof(float); + const size_t alignedBOffset = (kleidiAIPackedBSize + FloatAlignment - 1) & ~(FloatAlignment - 1); + const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1); + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { auto* Data = const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(&DataParams[gemm_i]); - // BZpCorr is at the end of packed B data + if (Data->QuantBZeroPoint == nullptr) { + continue; + } + // BZpCorr is at the end of packed B data (float-aligned) if (kleidiAIPackedBSize > 0 && Data->PackedQuantBData != nullptr) { - Data->BZpCorr = reinterpret_cast<const float*>(Data->PackedQuantBData + kleidiAIPackedBSize); + Data->BZpCorr = reinterpret_cast<const float*>(Data->PackedQuantBData + alignedBOffset); } - // AFloatBlkSum is at the end of the workspace's KleidiAI packed A + // AFloatBlkSum is at the end of the workspace's KleidiAI packed A (float-aligned) if (kleidiAIPackedASize > 0 && Workspace != nullptr) { std::byte* wsBase = reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride; - Data->AFloatBlkSum = reinterpret_cast<const float*>(wsBase + kleidiAIPackedASize); + Data->AFloatBlkSum = reinterpret_cast<const float*>(wsBase + alignedAOffset); } } } diff --git
a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 9037fc18258e9..501461ad455c4 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -70,6 +70,9 @@ QNBitGemmPackQuantBDataSize( const size_t sr = ukernel.get_sr(); size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); if (HasZeroPoint) { + // Align so that BZpCorrection starts at a float-aligned offset + constexpr size_t FloatAlignment = alignof(float); + packed_size = (packed_size + FloatAlignment - 1) & ~(FloatAlignment - 1); // Additional space for BZpCorrection: N * BlockCountK floats const size_t BlockCountK = MlasDivRoundup(K, BlkLen); packed_size += N * BlockCountK * sizeof(float); @@ -214,7 +217,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); // Pack B data with KleidiAI (only when B data is provided) - if (QuantBDataBegin != nullptr && QuantBScaleBegin != nullptr) { + if (QuantBDataBegin != nullptr) { + assert(QuantBScaleBegin != nullptr); kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; @@ -239,7 +243,10 @@ SQ4BitGemmPackQuantBDataAndBlkSum( if (HasZeroPoint && QuantBZPBegin != nullptr && QuantBScaleBegin != nullptr) { const size_t kleidiai_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( N, K, nr, kr, sr, BlkLen, kai_dt_bf16); - float* BZpCorr = reinterpret_cast<float*>(PackedQuantBDataBegin + kleidiai_packed_size); + // Align offset so BZpCorr starts at a float-aligned address + constexpr size_t FloatAlignment = alignof(float); + const size_t bzpcorr_offset = (kleidiai_packed_size + FloatAlignment - 1) & ~(FloatAlignment - 1); + float* BZpCorr = reinterpret_cast<float*>(PackedQuantBDataBegin + bzpcorr_offset); for (size_t n = 0; n < N; ++n) { for (size_t blk = 0; blk < BlockCountK; ++blk) { @@
-457,6 +464,9 @@ QNBitGemmPerGemmWorkspaceSize( const size_t sr = ukernel.get_sr(); size_t ws = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); if (HasZeroPoint) { + // Align so that AFloatBlkSum starts at a float-aligned offset + constexpr size_t FloatAlignment = alignof(float); + ws = (ws + FloatAlignment - 1) & ~(FloatAlignment - 1); // Additional space for AFloatBlkSum: M * BlockCountK floats const size_t BlockCountK = MlasDivRoundup(K, BlkLen); ws += M * BlockCountK * sizeof(float); From a25adc7d9ea6fd4a2861c6520a9c8b521bb1c5cf Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 19 Mar 2026 00:49:15 +0000 Subject: [PATCH 08/10] Address review: strict aliasing, dynamic zero_points guard, doc - Fix strict aliasing UB in float->bf16 scale conversion: use std::memcpy instead of reinterpret_cast. - Gate KleidiAI packed-scales path on constant zero_points: if zero_points is dynamic (TryGetConstantInput returns nullptr), fall back to non-KleidiAI path instead of marking scales as packed with uninitialized BZpCorr. - Document intentional use of fp32 scales (not bf16-truncated) in BZpCorr computation for higher correction precision. 
--- .../contrib_ops/cpu/quantization/matmul_nbits.cc | 15 ++++++++++++++- .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 10 ++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 6093a8d8df16e..f2bacd2be67bb 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -305,10 +305,23 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All if (input_idx == InputIndex::scales && packed_b_ != nullptr && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_, &mlas_backend_kernel_selector_config_)) { + // For asymmetric quantization, we require zero_points to be a constant initializer + // in order to safely use the KleidiAI packed-scales path. If zero_points is not + // constant, fall back to the non-KleidiAI path by leaving scales unpacked. + if (has_zp_input_) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor == nullptr) { + // zero_points is dynamic: do not mark scales as packed so that the + // execution falls back to the non-KleidiAI path. + return Status::OK(); + } + } + scales_are_packed_ = true; is_packed = true; - // For KleidiAI asymmetric path: compute BZpCorr now while scales are still accessible. + // For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales are still accessible. // After this PrePack returns is_packed=true, ORT may erase scales from the constant // input table (use count drops to 0), making them unavailable in later PrePack calls. // Zero points haven't been PrePacked yet so they are still accessible. 
diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 501461ad455c4..7e17aed0f1c0e 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -20,6 +20,7 @@ Module Name: #include #include +#include <cstring> #include #include @@ -227,8 +228,10 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t scales_len = N * BlockCountK; std::vector<uint16_t> scales(scales_len); for (size_t i = 0; i < scales_len; i++) { - const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&QuantBScaleBegin[i]); - scales[i] = *i32 >> 16; + uint32_t bits; + static_assert(sizeof(bits) == sizeof(QuantBScaleBegin[i]), "Unexpected float size"); + std::memcpy(&bits, &QuantBScaleBegin[i], sizeof(bits)); + scales[i] = static_cast<uint16_t>(bits >> 16); } kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen, @@ -239,6 +242,9 @@ // Compute BZpCorrection when both scales and zero points are available. // BZpCorr[n * BlockCountK + blk] = scale_b * (8 - zp_b) + // Note: We intentionally use fp32 scales here (not bf16-truncated) for higher precision + // in the correction term. KleidiAI internally truncates scales to bf16 for the main GEMM, + // but the correction benefits from full fp32 precision to better approximate the true result. // This may be called separately from B packing when zero points arrive later. if (HasZeroPoint && QuantBZPBegin != nullptr && QuantBScaleBegin != nullptr) { const size_t kleidiai_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( From 6684be5ec0bc5d8ae49b65687eeb391ec20e2a6c Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 19 Mar 2026 08:48:59 +0000 Subject: [PATCH 09/10] Fix fp16 path missing BZpCorr on macOS ARM The MLFloat16 PrePack specialization (used on macOS ARM where MLAS_F16VEC_INTRINSICS_SUPPORTED is not defined) was missing the KleidiAI asymmetric BZpCorr computation.
B packing passed nullptr for zero_points, leaving BZpCorr uninitialized. The correction GEMM then added garbage values. Fix: compute BZpCorr during input_idx==B in the fp16 PrePack path, where both scales (already converted to fp32) and zero_points (not yet PrePacked) are accessible. --- .../contrib_ops/cpu/quantization/matmul_nbits.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index f2bacd2be67bb..cb7cfbb4fb97a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -432,6 +432,22 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); + +#if defined(MLAS_TARGET_ARM64) + // For KleidiAI asymmetric 4-bit path: compute BZpCorr during B packing. + // The fp16 specialization packs B here (with scales already converted to fp32), + // so we also compute BZpCorr now while both scales and zero_points are accessible. 
+ if (has_zp_input_ && nbits_ == 4 && scales_fp32_ != nullptr) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor != nullptr) { + auto zptr = zp_tensor->Data<uint8_t>(); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + } + } +#endif // MLAS_TARGET_ARM64 + is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { bool should_pack_scale_and_zp = [&]() { From 75c0b38d9067aeb53478e446296bd9b6e1445dca Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 19 Mar 2026 21:55:35 +0000 Subject: [PATCH 10/10] Remove dead UseKleidiAI(HasZp) overload, rename UseKleidiAIBase - UseKleidiAI(K, BlkLen, HasZp, ...) had zero callers after the asymmetric support was added. Removed it. - Renamed UseKleidiAIBase() back to UseKleidiAI() since it is now the only variant. It checks KleidiAI eligibility (BlkLen, K, dotprod) without zero-point gating. - Added clarifying comment that BZpCorr/AFloatBlkSum correction condition is only true for the KleidiAI asymmetric path. --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 4 +++- .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 14 ++++---------- onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | 5 +---- .../core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 2 +- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 7ab94e8d800d2..e861a26f188ba 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -944,7 +944,9 @@ SQ4BitGemm_CompInt8( DataParams->C, RangeStartM, RangeCountM, RangeStartN, RangeCountN, K, DataParams->ldc, DataParams->Bias); - // Apply zero-point correction for asymmetric quantization:
+ // BZpCorr and AFloatBlkSum are only set when KleidiAI is active with asymmetric + // quantization (has zero points). On all other paths they remain nullptr. // C += AFloatBlkSum * BZpCorr^T (for this tile's M/N ranges) if (DataParams->BZpCorr != nullptr && DataParams->AFloatBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 7e17aed0f1c0e..ac42ced83f36c 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -63,7 +63,7 @@ QNBitGemmPackQuantBDataSize( #endif #ifdef USE_KLEIDIAI - if (ComputeType == SQNBIT_CompInt8 && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) { + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig)) { const auto& k = GetKleidiAIGemmUKernel(); const auto& ukernel = k.ukernel; const size_t nr = ukernel.get_nr(); @@ -207,7 +207,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( assert(BlkLen >= 16 && BlkLen % 16 == 0); #ifdef USE_KLEIDIAI - if (UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) { + if (UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig)) { const auto& k = GetKleidiAIGemmUKernel(); const auto& ukernel = k.ukernel; std::byte* PackedQuantBDataBegin = PackedQuantB.PackedQuantBData; @@ -461,7 +461,7 @@ QNBitGemmPerGemmWorkspaceSize( case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 #ifdef USE_KLEIDIAI - if (BlkBitWidth == 4 && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) { + if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig)) { const auto& k = (M == 1) ? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); const auto& ukernel = k.ukernel; @@ -522,7 +522,7 @@ QNBitGemmPerGemmWorkspaceAlignment( } // namespace bool -UseKleidiAIBase(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) +UseKleidiAI(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) { #ifdef USE_KLEIDIAI if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) { @@ -539,12 +539,6 @@ UseKleidiAIBase(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONF #endif } -bool -UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig) -{ - return !HasZp && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig); -} - template size_t SQ8BitGemmKernel_BlkSum_CompInt8( diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index 3451b659dfb84..2c31017732a0a 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -244,10 +244,7 @@ ApplyBZpCorrection( #endif bool -UseKleidiAIBase(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); - -bool -UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); +UseKleidiAI(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); // // General helpers. 
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index d1bdfce11dc71..5b09495436eaf 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -136,7 +136,7 @@ UsePacked_CompInt8(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNE { MLAS_UNREFERENCED_PARAMETER(HasZp); // Use KleidiAI packed path for both symmetric and asymmetric (with ZP correction). - return UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig); + return UseKleidiAI(K, BlkLen, BackendKernelSelectorConfig); } #ifdef USE_KLEIDIAI