Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,21 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
has_zp_input_, &mlas_backend_kernel_selector_config_)) {
scales_are_packed_ = true;
is_packed = true;

// For KleidiAI asymmetric path: compute BZpCorr now while scales are still accessible.
// After this PrePack returns is_packed=true, ORT may erase scales from the constant
// input table (use count drops to 0), making them unavailable in later PrePack calls.
// Zero points haven't been PrePacked yet so they are still accessible.
if (has_zp_input_ && nbits_ == 4) {
const Tensor* zp_tensor = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
if (zp_tensor != nullptr) {
auto sptr = tensor.Data<float>();
auto zptr = zp_tensor->Data<uint8_t>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_);
}
Comment thread
jambayk marked this conversation as resolved.
}
}
#endif // MLAS_TARGET_ARM64
} else if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 8) {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS {

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<T>* PostProcessor = nullptr;

const float* BZpCorr = nullptr; ///< optional: BZpCorrection for KleidiAI asymmetric path (N * BlockCountK floats)
const float* AFloatBlkSum = nullptr; ///< optional: float-domain A block sums for KleidiAI asymmetric path (M * BlockCountK floats)
};

/**
Expand Down
82 changes: 82 additions & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,32 @@ SQ4BitGemm_CompInt8(
SQ4BitGemm(BlkLen, QuantA, DataParams->PackedQuantBData,
DataParams->C, RangeStartM, RangeCountM, RangeStartN, RangeCountN, K,
DataParams->ldc, DataParams->Bias);

// Apply zero-point correction for asymmetric quantization:
// C += AFloatBlkSum * BZpCorr^T (for this tile's M/N ranges)
if (DataParams->BZpCorr != nullptr && DataParams->AFloatBlkSum != nullptr) {
Comment thread
jambayk marked this conversation as resolved.
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t ldc = DataParams->ldc;
const float* ABlkSum = DataParams->AFloatBlkSum + RangeStartM * BlockCountK;
const float* BCorr = DataParams->BZpCorr + RangeStartN * BlockCountK;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;

const auto ApplyCorrection = GetMlasPlatform().QNBitGemmDispatch->ApplyBZpCorrection;
if (ApplyCorrection) {
ApplyCorrection(ABlkSum, BCorr, C, RangeCountM, RangeCountN, BlockCountK, ldc);
} else {
// Scalar fallback
for (size_t m = 0; m < RangeCountM; ++m) {
for (size_t n = 0; n < RangeCountN; ++n) {
float corr = 0.0f;
for (size_t blk = 0; blk < BlockCountK; ++blk) {
corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk];
}
C[m * ldc + n] += corr;
}
}
}
}
return;
}

Expand Down Expand Up @@ -1186,6 +1212,7 @@ InitializeWorkspace_CompInt8<float>(

const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8;
const auto QuantizeA_Packed = GetMlasPlatform().QNBitGemmDispatch->QuantizeA_Packed_CompInt8;
const auto ComputeAFloatBlkSumFn = GetMlasPlatform().QNBitGemmDispatch->ComputeAFloatBlkSum;
const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8;
const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8;

Expand All @@ -1194,12 +1221,28 @@ InitializeWorkspace_CompInt8<float>(

// TODO: try parallel on BatchN * M threads because BatchN is usually 1.
if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) {
// Compute KleidiAI packed A size (same as workspace size without zero points)
const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize
? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize(
M, N, K, BlkLen, /*HasZeroPoint=*/false, SQNBIT_CompInt8, BlkBitWidth, BackendKernelSelectorConfig)
: 0;

MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];

const float* ARowPtr = data.A;
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr, BackendKernelSelectorConfig);

// For asymmetric KleidiAI path, also compute float-domain A block sums
// for zero-point correction. AFloatBlkSum is stored after KleidiAI packed A.
if (data.QuantBZeroPoint != nullptr && ComputeAFloatBlkSumFn != nullptr && kleidiAIPackedASize > 0) {
// Align offset so AFloatBlkSum starts at a float-aligned address
constexpr size_t FloatAlignment = alignof(float);
const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1);
float* AFloatBlkSum = reinterpret_cast<float*>(QuantARowPtr + alignedAOffset);
ComputeAFloatBlkSumFn(ARowPtr, M, K, BlkLen, data.lda, AFloatBlkSum);
}
});
} else {
// TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms
Expand Down Expand Up @@ -1482,6 +1525,45 @@ MlasQNBitGemmBatch(

const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

// For KleidiAI asymmetric path: set up BZpCorr and AFloatBlkSum pointers.
// BZpCorr is stored after KleidiAI packed B data.
// AFloatBlkSum is stored after KleidiAI packed A data in each per-GEMM workspace.
const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8;
if (Variant == SQ4BitGemmVariant_CompInt8 && has_zp_input && UsePacked &&
UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) {
// Compute KleidiAI packed B size (without zero point correction space)
const size_t kleidiAIPackedBSize = GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize
? GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize(
N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BackendKernelSelectorConfig)
: 0;
// KleidiAI packed A size (workspace without AFloatBlkSum)
const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize
? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize(
M, N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BlkBitWidth, BackendKernelSelectorConfig)
: 0;

// Align offsets so float arrays start at float-aligned addresses
constexpr size_t FloatAlignment = alignof(float);
const size_t alignedBOffset = (kleidiAIPackedBSize + FloatAlignment - 1) & ~(FloatAlignment - 1);
const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1);

for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
auto* Data = const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(&DataParams[gemm_i]);
if (Data->QuantBZeroPoint == nullptr) {
continue;
}
// BZpCorr is at the end of packed B data (float-aligned)
if (kleidiAIPackedBSize > 0 && Data->PackedQuantBData != nullptr) {
Data->BZpCorr = reinterpret_cast<const float*>(Data->PackedQuantBData + alignedBOffset);
}
Comment thread
jambayk marked this conversation as resolved.
// AFloatBlkSum is at the end of the workspace's KleidiAI packed A (float-aligned)
if (kleidiAIPackedASize > 0 && Workspace != nullptr) {
std::byte* wsBase = reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
Data->AFloatBlkSum = reinterpret_cast<const float*>(wsBase + alignedAOffset);
}
}
}

if (ThreadPool == nullptr) {
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
const auto* Data = &DataParams[gemm_i];
Expand Down
46 changes: 46 additions & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,52 @@ struct MLAS_QNBIT_GEMM_DISPATCH {

QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr;

/**
* @brief Compute float-domain block sums of A for zero-point correction.
* Used when KleidiAI handles asymmetric quantization.
*
* @param A Supplies the float A matrix.
* @param CountM Number of rows of A.
* @param CountK Number of columns of A.
* @param BlkLen Number of values in a block.
* @param lda Leading dimension of A.
* @param[out] AFloatBlkSum Output: M * BlockCountK float sums.
*/
typedef void(ComputeAFloatBlkSum_Fn)(
const float* A,
size_t CountM,
size_t CountK,
size_t BlkLen,
size_t lda,
float* AFloatBlkSum
);

ComputeAFloatBlkSum_Fn* ComputeAFloatBlkSum = nullptr;

/**
* @brief Apply zero-point correction: C += ABlkSum * BCorr^T
* Used after KleidiAI GEMM for asymmetric quantization.
*
* @param ABlkSum Float block sums of A, [RangeCountM, BlockCountK] row-major.
* @param BCorr BZpCorrection, [RangeCountN, BlockCountK] row-major (pre-offset).
* @param[out] C Output matrix tile (pre-offset), accumulated.
* @param RangeCountM Number of M rows in this tile.
* @param RangeCountN Number of N columns in this tile.
* @param BlockCountK Number of blocks along K.
* @param ldc Leading dimension of C.
*/
typedef void(ApplyBZpCorrection_Fn)(
const float* ABlkSum,
const float* BCorr,
float* C,
size_t RangeCountM,
size_t RangeCountN,
size_t BlockCountK,
size_t ldc
);

ApplyBZpCorrection_Fn* ApplyBZpCorrection = nullptr;

/**
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
*
Expand Down
98 changes: 75 additions & 23 deletions onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,22 @@ QNBitGemmPackQuantBDataSize(
#endif

#ifdef USE_KLEIDIAI
if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) {
if (ComputeType == SQNBIT_CompInt8 && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) {
const auto& k = GetKleidiAIGemmUKernel();
const auto& ukernel = k.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16);
size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16);
if (HasZeroPoint) {
// Align so that BZpCorrection starts at a float-aligned offset
constexpr size_t FloatAlignment = alignof(float);
packed_size = (packed_size + FloatAlignment - 1) & ~(FloatAlignment - 1);
// Additional space for BZpCorrection: N * BlockCountK floats
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
packed_size += N * BlockCountK * sizeof(float);
Comment thread
jambayk marked this conversation as resolved.
}
return packed_size;
} else
#endif
{
Expand Down Expand Up @@ -183,7 +192,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum(
const std::byte* QuantBDataBegin,
const float* QuantBScaleBegin,
bool HasZeroPoint,
const std::byte*,
const std::byte* QuantBZPBegin,
PackedQuantBDataStruct<float, 4>& PackedQuantB,
MLAS_THREADPOOL* ThreadPool,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
Expand All @@ -192,36 +201,63 @@ SQ4BitGemmPackQuantBDataAndBlkSum(
#ifndef USE_KLEIDIAI
MLAS_UNREFERENCED_PARAMETER(QuantBScaleBegin);
MLAS_UNREFERENCED_PARAMETER(HasZeroPoint);
MLAS_UNREFERENCED_PARAMETER(QuantBZPBegin);
#endif
assert(BlkLen >= 16 && BlkLen % 16 == 0);

#ifdef USE_KLEIDIAI
if (UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) {
if (UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) {
const auto& k = GetKleidiAIGemmUKernel();
const auto& ukernel = k.ukernel;
std::byte* PackedQuantBDataBegin = PackedQuantB.PackedQuantBData;

const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
params.scale_dt = kai_dt_bf16;
// Pack B data with KleidiAI (only when B data is provided)
if (QuantBDataBegin != nullptr) {
assert(QuantBScaleBegin != nullptr);
kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
params.scale_dt = kai_dt_bf16;

const size_t scales_len = N * BlockCountK;
std::vector<uint16_t> scales(scales_len);
for (size_t i = 0; i < scales_len; i++) {
const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&QuantBScaleBegin[i]);
scales[i] = *i32 >> 16;
}
Comment thread
jambayk marked this conversation as resolved.

const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t scales_len = N * BlockCountK;
std::vector<uint16_t> scales(scales_len);
for (size_t i = 0; i < scales_len; i++) {
const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&QuantBScaleBegin[i]);
scales[i] = *i32 >> 16;
kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen,
reinterpret_cast<const uint8_t*>(QuantBDataBegin), BlockCountK * BlkLen / 2,
nullptr, scales.data(), BlockCountK * sizeof(uint16_t),
PackedQuantBDataBegin, 0, &params);
}

kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen,
reinterpret_cast<const uint8_t*>(QuantBDataBegin), BlockCountK * BlkLen / 2,
nullptr, scales.data(), BlockCountK * sizeof(uint16_t),
PackedQuantBDataBegin, 0, &params);
// Compute BZpCorrection when both scales and zero points are available.
// BZpCorr[n * BlockCountK + blk] = scale_b * (8 - zp_b)
// This may be called separately from B packing when zero points arrive later.
if (HasZeroPoint && QuantBZPBegin != nullptr && QuantBScaleBegin != nullptr) {
const size_t kleidiai_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
N, K, nr, kr, sr, BlkLen, kai_dt_bf16);
// Align offset so BZpCorr starts at a float-aligned address
constexpr size_t FloatAlignment = alignof(float);
const size_t bzpcorr_offset = (kleidiai_packed_size + FloatAlignment - 1) & ~(FloatAlignment - 1);
Comment thread
jambayk marked this conversation as resolved.
float* BZpCorr = reinterpret_cast<float*>(PackedQuantBDataBegin + bzpcorr_offset);

for (size_t n = 0; n < N; ++n) {
for (size_t blk = 0; blk < BlockCountK; ++blk) {
const size_t idx = n * BlockCountK + blk;
const size_t zp_byte_idx = blk / 2;
const uint8_t zp_byte = static_cast<uint8_t>(QuantBZPBegin[n * MlasDivRoundup(BlockCountK, 2) + zp_byte_idx]);
const uint8_t zp = (blk & 1) == 0 ? (zp_byte & 0x0F) : (zp_byte >> 4);
BZpCorr[idx] = QuantBScaleBegin[idx] * (8.0f - static_cast<float>(zp));
}
Comment thread
jambayk marked this conversation as resolved.
}
}
} else
#endif
{
Expand Down Expand Up @@ -419,14 +455,23 @@ QNBitGemmPerGemmWorkspaceSize(
case SQNBIT_CompInt8: {
// workspace buffer is used for block quantization of A to int8
#ifdef USE_KLEIDIAI
if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint, BackendKernelSelectorConfig)) {
if (BlkBitWidth == 4 && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig)) {
const auto& k = (M == 1) ? GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel();
const auto& ukernel = k.ukernel;

const size_t mr = ukernel.get_mr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
size_t ws = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
if (HasZeroPoint) {
// Align so that AFloatBlkSum starts at a float-aligned offset
constexpr size_t FloatAlignment = alignof(float);
ws = (ws + FloatAlignment - 1) & ~(FloatAlignment - 1);
// Additional space for AFloatBlkSum: M * BlockCountK floats
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
ws += M * BlockCountK * sizeof(float);
}
return ws;
} else
#endif
{
Expand Down Expand Up @@ -471,24 +516,29 @@ QNBitGemmPerGemmWorkspaceAlignment(
} // namespace

bool
UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig)
UseKleidiAIBase(size_t K, size_t BlkLen, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig)
{
#ifdef USE_KLEIDIAI
if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) {
return false;
}

bool has_dotprod = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();
return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && !HasZp && has_dotprod;
return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && has_dotprod;
#else
MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig);
MLAS_UNREFERENCED_PARAMETER(K);
MLAS_UNREFERENCED_PARAMETER(BlkLen);
MLAS_UNREFERENCED_PARAMETER(HasZp);
return false;
#endif
}

bool
UseKleidiAI(size_t K, size_t BlkLen, bool HasZp, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig)
{
return !HasZp && UseKleidiAIBase(K, BlkLen, BackendKernelSelectorConfig);
Comment thread
jambayk marked this conversation as resolved.
Outdated
}

template<bool QuantAUnsigned>
size_t
SQ8BitGemmKernel_BlkSum_CompInt8(
Expand Down Expand Up @@ -601,6 +651,8 @@ GetMlasQNBitGemmDispatchNeon(
#ifdef USE_KLEIDIAI
d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8;
d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8;
d.ComputeAFloatBlkSum = sqnbitgemm_neon::ComputeAFloatBlkSum;
d.ApplyBZpCorrection = sqnbitgemm_neon::ApplyBZpCorrection;
#endif
}

Expand Down
Loading
Loading