
Enable KleidiAI for asymmetric 4-bit MatMulNBits on ARM64#27751

Merged
jambayk merged 10 commits into main from jambayk/mnb-arm-32
Mar 20, 2026

Conversation

@jambayk
Contributor

@jambayk jambayk commented Mar 18, 2026

Description

Extends the KleidiAI-accelerated MatMulNBits path on ARM64 to support asymmetric 4-bit quantization (models with per-block zero points). Previously, KleidiAI was only used for symmetric quantization (!HasZp guard), and asymmetric models fell back to a significantly slower non-KleidiAI kernel.

Since KleidiAI only provides symmetric int4 micro-kernels (kai_matmul_clamp_f32_qai8dxp_qsi4c32p), this PR runs KleidiAI as if the weights were symmetric (with a hardcoded rhs_zero_point of 8) and applies a float-domain zero-point correction after the GEMM:

$$C_{\text{actual}} = C_{\text{symmetric}} + A_{\text{blksum}} \times B_{\text{ZpCorr}}^T$$

where:

  • BZpCorr[n, blk] = scale_b[n, blk] × (8 - zp_b[n, blk]) — precomputed at weight packing time
  • AFloatBlkSum[m, blk] = Σ A_float[m, blk_start..blk_end] — computed per inference alongside A quantization

Key changes:

  • UseKleidiAI(): Removed the !HasZp guard so KleidiAI is now eligible for both symmetric and asymmetric models (checks BlkLen%32==0, K%BlkLen==0, dotprod capability).
  • B packing (SQ4BitGemmPackQuantBDataAndBlkSum): Computes and stores BZpCorr after KleidiAI packed B data when zero points are present.
  • Workspace expansion: Allocates space for AFloatBlkSum (M × BlockCountK floats) in the per-GEMM workspace for asymmetric models.
  • ComputeAFloatBlkSum: NEON-vectorized (4× unrolled vaddq_f32) function to compute per-block float sums of A.
  • ApplyBZpCorrection: NEON-vectorized correction kernel tiled 4-N-wide (vfmaq_f32) for L1-friendly BZpCorr reuse.
  • PrePack: Computes BZpCorr during the scales PrePack (not zero_points PrePack), since ORT may erase constant inputs after marking them packed.

No changes to the symmetric path. No changes to x64. No changes to 8-bit quantization.

Motivation and Context

Asymmetric 4-bit quantized models (e.g., GPTQ/RTN with zero points) on ARM64 were 23–72% slower than their symmetric counterparts because KleidiAI's sdot/i8mm micro-kernels only support symmetric RHS, forcing a fallback to a slower non-KleidiAI kernel path.

This change closes most of that gap:

| Model | Seq Len | Asym/Sym (before) | Asym/Sym (after) | Asym speedup | Asym latency (after) |
| --- | --- | --- | --- | --- | --- |
| Qwen 1.5B | 256 | 1.35× | 1.17× | 1.16× | 1107.8 ms |
| Qwen 1.5B | 512 | 1.23× | 1.06× | 1.14× | 2259.7 ms |
| Qwen 3B | 256 | 1.43× | 1.12× | 1.28× | 2029.7 ms |
| Qwen 3B | 512 | 1.39× | 1.22× | 1.24× | 4188.0 ms |
| Qwen 7B | 256 | 1.61× | 1.11× | 1.52× | 3661.6 ms |
| Qwen 7B | 512 | 1.72× | 1.11× | 1.58× | 7263.8 ms |

The remaining 6–22% asym/sym gap comes from the extra pass over A to compute float block sums — this cannot be fused into KleidiAI's sealed A-packing function and would require an upstream KleidiAI API change.

@jambayk jambayk requested a review from Copilot March 18, 2026 23:20
Contributor

Copilot AI left a comment


Pull request overview

This PR extends the ARM64 KleidiAI-accelerated MatMulNBits (SQNBIT int8 compute, int4 weights) path to support asymmetric 4-bit quantization by running KleidiAI in a symmetric mode and applying a post-GEMM floating-point zero-point correction.

Changes:

  • Introduces UseKleidiAIBase() to determine KleidiAI eligibility independent of HasZp, enabling KleidiAI for asymmetric int4 weights.
  • Adds packing-time precomputation/storage of BZpCorr and runtime computation of AFloatBlkSum, plus a NEON correction kernel C += ABlkSum * BCorr^T.
  • Expands per-GEMM workspace and dispatch plumbing to carry/apply the correction data.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | Adds NEON kernels to compute A block sums and apply the asymmetric ZP correction; updates packed-path selection to use UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new helper functions (ComputeAFloatBlkSum, ApplyBZpCorrection) and UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Updates KleidiAI packing-size/workspace sizing and computes/stores BZpCorr adjacent to KleidiAI packed-B. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends MLAS_QNBIT_GEMM_DISPATCH with function pointers for A-sum computation and correction application. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Wires correction into the KleidiAI packed SQ4 path and sets up runtime pointers into packed-B / per-GEMM workspace. |
| onnxruntime/core/mlas/inc/mlas_qnbit.h | Extends MLAS_QNBIT_GEMM_DATA_PARAMS with BZpCorr and AFloatBlkSum pointers. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Computes BZpCorr during scales PrePack for ARM64 KleidiAI packed path to avoid constant-input erasure issues. |


- Align BZpCorr and AFloatBlkSum offsets to alignof(float) to avoid
  undefined behavior from misaligned float pointer dereferences.
  KleidiAI packed sizes may not be multiples of 4.
- Gate BZpCorr/AFloatBlkSum pointer setup on per-entry
  Data->QuantBZeroPoint != nullptr (skip entries without zero points).
- Assert QuantBScaleBegin != nullptr when QuantBDataBegin is provided
  in KleidiAI packing path (scales are required for RHS packing).
@jambayk jambayk requested a review from Copilot March 18, 2026 23:56
Contributor

Copilot AI left a comment


Pull request overview

Extends the ARM64 KleidiAI-accelerated MatMulNBits (SQ4 int4) path to support asymmetric 4-bit quantization by running KleidiAI in a symmetric configuration and applying an additional float-domain zero-point correction.

Changes:

  • Enables KleidiAI packed path selection for int4 even when RHS zero-points are present (via UseKleidiAIBase).
  • Adds computation of AFloatBlkSum and applies C += ABlkSum * BZpCorr^T after KleidiAI GEMM.
  • Extends B packing and per-GEMM workspace sizing/layout to store and use BZpCorr and AFloatBlkSum.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | Adds NEON implementations for ComputeAFloatBlkSum and ApplyBZpCorrection; updates packed-path selection to use UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new helper functions and UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Updates packing size logic and B packing to optionally compute/store BZpCorr; introduces UseKleidiAIBase. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends dispatch struct with function pointers for ComputeAFloatBlkSum and ApplyBZpCorrection. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Wires up correction in the SQ4 int8 KleidiAI packed path; computes/stashes BZpCorr/AFloatBlkSum pointers. |
| onnxruntime/core/mlas/inc/mlas_qnbit.h | Adds BZpCorr and AFloatBlkSum fields to the GEMM data params struct. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Computes BZpCorr during scales PrePack for ARM64 KleidiAI asymmetric path. |


jambayk added 2 commits March 19, 2026 00:49
- Fix strict aliasing UB in float->bf16 scale conversion: use
  std::memcpy instead of reinterpret_cast<uint32_t*>.
- Gate KleidiAI packed-scales path on constant zero_points: if
  zero_points is dynamic (TryGetConstantInput returns nullptr),
  fall back to non-KleidiAI path instead of marking scales as
  packed with uninitialized BZpCorr.
- Document intentional use of fp32 scales (not bf16-truncated) in
  BZpCorr computation for higher correction precision.
The MLFloat16 PrePack specialization (used on macOS ARM where
MLAS_F16VEC_INTRINSICS_SUPPORTED is not defined) was missing the
KleidiAI asymmetric BZpCorr computation. B packing passed nullptr
for zero_points, leaving BZpCorr uninitialized. The correction
GEMM then added garbage values.

Fix: compute BZpCorr during input_idx==B in the fp16 PrePack path,
where both scales (already converted to fp32) and zero_points
(not yet PrePacked) are accessible.
@jambayk jambayk enabled auto-merge (squash) March 19, 2026 16:47
@jambayk jambayk requested a review from hariharans29 March 19, 2026 16:48
- UseKleidiAI(K, BlkLen, HasZp, ...) had zero callers after the
  asymmetric support was added. Removed it.
- Renamed UseKleidiAIBase() back to UseKleidiAI() since it is now
  the only variant. It checks KleidiAI eligibility (BlkLen, K,
  dotprod) without zero-point gating.
- Added clarifying comment that BZpCorr/AFloatBlkSum correction
  condition is only true for the KleidiAI asymmetric path.
@jambayk jambayk requested a review from hariharans29 March 19, 2026 22:19
@jambayk jambayk merged commit d9bb760 into main Mar 20, 2026
91 of 92 checks passed
@jambayk jambayk deleted the jambayk/mnb-arm-32 branch March 20, 2026 02:18