Enable KleidiAI for asymmetric 4-bit MatMulNBits on ARM64 #27751
Conversation
Pull request overview
This PR extends the ARM64 KleidiAI-accelerated MatMulNBits (SQNBIT int8 compute, int4 weights) path to support asymmetric 4-bit quantization by running KleidiAI in a symmetric mode and applying a post-GEMM floating-point zero-point correction.
Changes:
- Introduces `UseKleidiAIBase()` to determine KleidiAI eligibility independent of `HasZp`, enabling KleidiAI for asymmetric int4 weights.
- Adds packing-time precomputation/storage of `BZpCorr` and runtime computation of `AFloatBlkSum`, plus a NEON correction kernel `C += ABlkSum * BCorr^T`.
- Expands per-GEMM workspace and dispatch plumbing to carry/apply the correction data.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | Adds NEON kernels to compute A block sums and apply the asymmetric ZP correction; updates packed-path selection to use UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new helper functions (ComputeAFloatBlkSum, ApplyBZpCorrection) and UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Updates KleidiAI packing-size/workspace sizing and computes/stores BZpCorr adjacent to KleidiAI packed-B. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends MLAS_QNBIT_GEMM_DISPATCH with function pointers for A-sum computation and correction application. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Wires correction into the KleidiAI packed SQ4 path and sets up runtime pointers into packed-B / per-GEMM workspace. |
| onnxruntime/core/mlas/inc/mlas_qnbit.h | Extends MLAS_QNBIT_GEMM_DATA_PARAMS with BZpCorr and AFloatBlkSum pointers. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Computes BZpCorr during scales PrePack for ARM64 KleidiAI packed path to avoid constant-input erasure issues. |
- Align `BZpCorr` and `AFloatBlkSum` offsets to `alignof(float)` to avoid undefined behavior from misaligned float pointer dereferences; KleidiAI packed sizes may not be multiples of 4.
- Gate `BZpCorr`/`AFloatBlkSum` pointer setup on per-entry `Data->QuantBZeroPoint != nullptr` (skip entries without zero points).
- Assert `QuantBScaleBegin != nullptr` when `QuantBDataBegin` is provided in the KleidiAI packing path (scales are required for RHS packing).
Pull request overview
Extends the ARM64 KleidiAI-accelerated MatMulNBits (SQ4 int4) path to support asymmetric 4-bit quantization by running KleidiAI in a symmetric configuration and applying an additional float-domain zero-point correction.
Changes:
- Enables KleidiAI packed path selection for int4 even when RHS zero-points are present (via `UseKleidiAIBase`).
- Adds computation of `AFloatBlkSum` and applies `C += ABlkSum * BZpCorr^T` after KleidiAI GEMM.
- Extends B packing and per-GEMM workspace sizing/layout to store and use `BZpCorr` and `AFloatBlkSum`.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | Adds NEON implementations for ComputeAFloatBlkSum and ApplyBZpCorrection; updates packed-path selection to use UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new helper functions and UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Updates packing size logic and B packing to optionally compute/store BZpCorr; introduces UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends dispatch struct with function pointers for ComputeAFloatBlkSum and ApplyBZpCorrection. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Wires up correction in the SQ4 int8 KleidiAI packed path; computes/stashes BZpCorr/AFloatBlkSum pointers. |
| onnxruntime/core/mlas/inc/mlas_qnbit.h | Adds BZpCorr and AFloatBlkSum fields to the GEMM data params struct. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Computes BZpCorr during scales PrePack for ARM64 KleidiAI asymmetric path. |
- Fix strict-aliasing UB in the float→bf16 scale conversion: use `std::memcpy` instead of `reinterpret_cast<uint32_t*>`.
- Gate the KleidiAI packed-scales path on constant `zero_points`: if `zero_points` is dynamic (`TryGetConstantInput` returns `nullptr`), fall back to the non-KleidiAI path instead of marking scales as packed with an uninitialized `BZpCorr`.
- Document the intentional use of fp32 scales (not bf16-truncated) in the `BZpCorr` computation for higher correction precision.
The `MLFloat16` PrePack specialization (used on macOS ARM, where `MLAS_F16VEC_INTRINSICS_SUPPORTED` is not defined) was missing the KleidiAI asymmetric `BZpCorr` computation: B packing passed `nullptr` for `zero_points`, leaving `BZpCorr` uninitialized, so the correction GEMM added garbage values. Fix: compute `BZpCorr` during `input_idx == B` in the fp16 PrePack path, where both the scales (already converted to fp32) and the zero_points (not yet PrePacked) are accessible.
- `UseKleidiAI(K, BlkLen, HasZp, ...)` had zero callers after the asymmetric support was added; removed it.
- Renamed `UseKleidiAIBase()` back to `UseKleidiAI()` since it is now the only variant. It checks KleidiAI eligibility (`BlkLen`, `K`, dotprod) without zero-point gating.
- Added a clarifying comment that the `BZpCorr`/`AFloatBlkSum` correction condition is only true for the KleidiAI asymmetric path.
Description
Extends the KleidiAI-accelerated `MatMulNBits` path on ARM64 to support asymmetric 4-bit quantization (models with per-block zero points). Previously, KleidiAI was only used for symmetric quantization (a `!HasZp` guard), and asymmetric models fell back to a significantly slower non-KleidiAI kernel.

Since KleidiAI only provides symmetric int4 micro-kernels (`kai_matmul_clamp_f32_qai8dxp_qsi4c32p`), this PR runs KleidiAI as-if symmetric (hardcoded `rhs_zero_point = 8`) and applies a float-domain zero-point correction post-GEMM, `C += AFloatBlkSum * BZpCorr^T`, where:

- `BZpCorr[n, blk] = scale_b[n, blk] × (8 - zp_b[n, blk])` — precomputed at weight packing time
- `AFloatBlkSum[m, blk] = Σ A_float[m, blk_start..blk_end]` — computed per inference alongside A quantization

Key changes:

- `UseKleidiAI()`: removed the `!HasZp` guard so KleidiAI is now eligible for both symmetric and asymmetric models (checks `BlkLen % 32 == 0`, `K % BlkLen == 0`, dotprod capability).
- B packing (`SQ4BitGemmPackQuantBDataAndBlkSum`): computes and stores `BZpCorr` after the KleidiAI packed B data when zero points are present.
- Workspace: allocates `AFloatBlkSum` (M × BlockCountK floats) in the per-GEMM workspace for asymmetric models.
- `ComputeAFloatBlkSum`: NEON-vectorized (4× unrolled `vaddq_f32`) function to compute per-block float sums of A.
- `ApplyBZpCorrection`: NEON-vectorized correction kernel tiled 4-N-wide (`vfmaq_f32`) for L1-friendly `BZpCorr` reuse.
- Computes `BZpCorr` during the scales PrePack (not the zero_points PrePack), since ORT may erase constant inputs after marking them packed.

No changes to the symmetric path. No changes to x64. No changes to 8-bit quantization.
Motivation and Context
Asymmetric 4-bit quantized models (e.g., GPTQ/RTN with zero points) on ARM64 were 23–72% slower than their symmetric counterparts because KleidiAI's `sdot`/`i8mm` micro-kernels only support a symmetric RHS, forcing a fallback to a slower non-KleidiAI kernel path.

This change closes most of that gap.
The remaining 6–22% asym/sym gap comes from the extra pass over A to compute float block sums — this cannot be fused into KleidiAI's sealed A-packing function and would require an upstream KleidiAI API change.