44 changes: 44 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -305,8 +305,36 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
if (input_idx == InputIndex::scales && packed_b_ != nullptr &&
MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_,
has_zp_input_, &mlas_backend_kernel_selector_config_)) {
// For asymmetric quantization, we require zero_points to be a constant initializer
// in order to safely use the KleidiAI packed-scales path. If zero_points is not
// constant, fall back to the non-KleidiAI path by leaving scales unpacked.
if (has_zp_input_) {
const Tensor* zp_tensor = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
if (zp_tensor == nullptr) {
// zero_points is dynamic: do not mark scales as packed so that the
// execution falls back to the non-KleidiAI path.
return Status::OK();
}
}

scales_are_packed_ = true;
is_packed = true;

// For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales are still accessible.
// After this PrePack returns is_packed=true, ORT may erase scales from the constant
// input table (use count drops to 0), making them unavailable in later PrePack calls.
// Zero points haven't been PrePacked yet, so they are still accessible.
if (has_zp_input_ && nbits_ == 4) {
const Tensor* zp_tensor = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
if (zp_tensor != nullptr) {
auto sptr = tensor.Data<float>();
auto zptr = zp_tensor->Data<uint8_t>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_);
}
}
}
#endif // MLAS_TARGET_ARM64
} else if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 8) {
@@ -404,6 +432,22 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(),
scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_);

#if defined(MLAS_TARGET_ARM64)
// For KleidiAI asymmetric 4-bit path: compute BZpCorr during B packing.
// The fp16 specialization packs B here (with scales already converted to fp32),
// so we also compute BZpCorr now while both scales and zero_points are accessible.
if (has_zp_input_ && nbits_ == 4 && scales_fp32_ != nullptr) {
const Tensor* zp_tensor = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor);
if (zp_tensor != nullptr) {
auto zptr = zp_tensor->Data<uint8_t>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(),
scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_);
}
}
#endif // MLAS_TARGET_ARM64

is_packed = true;
} else if (compute_type_ == SQNBIT_CompInt8) {
bool should_pack_scale_and_zp = [&]() {
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -65,6 +65,9 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS {

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<T>* PostProcessor = nullptr;

const float* BZpCorr = nullptr; ///< optional: BZpCorrection for KleidiAI asymmetric path (N * BlockCountK floats)
const float* AFloatBlkSum = nullptr; ///< optional: float-domain A block sums for KleidiAI asymmetric path (M * BlockCountK floats)
};

/**
84 changes: 84 additions & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -943,6 +943,34 @@ SQ4BitGemm_CompInt8(
SQ4BitGemm(BlkLen, QuantA, DataParams->PackedQuantBData,
DataParams->C, RangeStartM, RangeCountM, RangeStartN, RangeCountN, K,
DataParams->ldc, DataParams->Bias);

// Apply zero-point correction for asymmetric quantization (KleidiAI path only).
// BZpCorr and AFloatBlkSum are only set when KleidiAI is active with asymmetric
// quantization (has zero points). On all other paths they remain nullptr.
// C += AFloatBlkSum * BZpCorr^T (for this tile's M/N ranges)
if (DataParams->BZpCorr != nullptr && DataParams->AFloatBlkSum != nullptr) {
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t ldc = DataParams->ldc;
const float* ABlkSum = DataParams->AFloatBlkSum + RangeStartM * BlockCountK;
const float* BCorr = DataParams->BZpCorr + RangeStartN * BlockCountK;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;

const auto ApplyCorrection = GetMlasPlatform().QNBitGemmDispatch->ApplyBZpCorrection;
if (ApplyCorrection) {
ApplyCorrection(ABlkSum, BCorr, C, RangeCountM, RangeCountN, BlockCountK, ldc);
} else {
// Scalar fallback
for (size_t m = 0; m < RangeCountM; ++m) {
for (size_t n = 0; n < RangeCountN; ++n) {
float corr = 0.0f;
for (size_t blk = 0; blk < BlockCountK; ++blk) {
corr += ABlkSum[m * BlockCountK + blk] * BCorr[n * BlockCountK + blk];
}
C[m * ldc + n] += corr;
}
}
}
}
return;
}
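For reference, the correction term applied here follows from expanding the asymmetric dequantization of B. Writing $b(k)$ for the block containing column $k$, and $s_{n,b}$, $z_{n,b}$ for output column $n$'s per-block scale and zero point:

$$
C_{mn}=\sum_{k} A_{mk}\,s_{n,b(k)}\bigl(Q_{nk}-z_{n,b(k)}\bigr)
=\underbrace{\sum_{k} A_{mk}\,s_{n,b(k)}\,Q_{nk}}_{\text{symmetric KleidiAI GEMM}}
-\sum_{b}\Bigl(\sum_{k\in b} A_{mk}\Bigr)\,s_{n,b}\,z_{n,b}
$$

The inner per-block sums are exactly `AFloatBlkSum[m, b]`, so the correction is the rank-`BlockCountK` update `C += AFloatBlkSum * BZpCorr^T` that the scalar fallback above implements. Since the update is an addition, `BZpCorr[n, b]` presumably stores $-\,s_{n,b}\,z_{n,b}$; that sign convention is fixed at B-packing time, outside this hunk.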

@@ -1186,6 +1214,7 @@ InitializeWorkspace_CompInt8<float>(

const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8;
const auto QuantizeA_Packed = GetMlasPlatform().QNBitGemmDispatch->QuantizeA_Packed_CompInt8;
const auto ComputeAFloatBlkSumFn = GetMlasPlatform().QNBitGemmDispatch->ComputeAFloatBlkSum;
const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8;
const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8;

@@ -1194,12 +1223,28 @@

// TODO: try parallel on BatchN * M threads because BatchN is usually 1.
if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) {
// Compute KleidiAI packed A size (same as workspace size without zero points)
const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize
? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize(
M, N, K, BlkLen, /*HasZeroPoint=*/false, SQNBIT_CompInt8, BlkBitWidth, BackendKernelSelectorConfig)
: 0;

MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];

const float* ARowPtr = data.A;
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr, BackendKernelSelectorConfig);

// For asymmetric KleidiAI path, also compute float-domain A block sums
// for zero-point correction. AFloatBlkSum is stored after KleidiAI packed A.
if (data.QuantBZeroPoint != nullptr && ComputeAFloatBlkSumFn != nullptr && kleidiAIPackedASize > 0) {
// Align offset so AFloatBlkSum starts at a float-aligned address
constexpr size_t FloatAlignment = alignof(float);
const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1);
float* AFloatBlkSum = reinterpret_cast<float*>(QuantARowPtr + alignedAOffset);
ComputeAFloatBlkSumFn(ARowPtr, M, K, BlkLen, data.lda, AFloatBlkSum);
}
});
} else {
// TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms
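A minimal sketch of the per-GEMM workspace layout assumed by the offset arithmetic above: KleidiAI packed A, padded to `alignof(float)`, followed by the `M * BlockCountK` block sums. `AlignUp` and `PerGemmWorkspaceBytes` are illustrative helpers, not names from this PR; the authoritative size is whatever `QNBitGemmPerGemmWorkspaceSize` returns with `HasZeroPoint=true`.

```cpp
#include <cstddef>

// Round size up to the next multiple of alignment (alignment must be a power of two).
// This is the same `(x + a - 1) & ~(a - 1)` trick used in the diff above.
constexpr std::size_t AlignUp(std::size_t size, std::size_t alignment) {
    return (size + alignment - 1) & ~(alignment - 1);
}

// Hypothetical: bytes needed per GEMM when AFloatBlkSum is appended after packed A.
std::size_t PerGemmWorkspaceBytes(std::size_t M, std::size_t BlockCountK,
                                  std::size_t kleidiAIPackedASize) {
    return AlignUp(kleidiAIPackedASize, alignof(float)) + M * BlockCountK * sizeof(float);
}
```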
@@ -1482,6 +1527,45 @@ MlasQNBitGemmBatch(

const size_t BlockCountK = MlasDivRoundup(K, BlkLen);

// For KleidiAI asymmetric path: set up BZpCorr and AFloatBlkSum pointers.
// BZpCorr is stored after KleidiAI packed B data.
// AFloatBlkSum is stored after KleidiAI packed A data in each per-GEMM workspace.
const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8;
if (Variant == SQ4BitGemmVariant_CompInt8 && has_zp_input && UsePacked &&
UsePacked(K, BlkLen, DataParams->QuantBZeroPoint, BackendKernelSelectorConfig)) {
// Compute KleidiAI packed B size (without zero point correction space)
const size_t kleidiAIPackedBSize = GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize
? GetMlasPlatform().QNBitGemmDispatch->Q4BitGemmPackQuantBDataSize(
N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BackendKernelSelectorConfig)
: 0;
// KleidiAI packed A size (workspace without AFloatBlkSum)
const size_t kleidiAIPackedASize = GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize
? GetMlasPlatform().QNBitGemmDispatch->QNBitGemmPerGemmWorkspaceSize(
M, N, K, BlkLen, /*HasZeroPoint=*/false, ComputeType, BlkBitWidth, BackendKernelSelectorConfig)
: 0;

// Align offsets so float arrays start at float-aligned addresses
constexpr size_t FloatAlignment = alignof(float);
const size_t alignedBOffset = (kleidiAIPackedBSize + FloatAlignment - 1) & ~(FloatAlignment - 1);
const size_t alignedAOffset = (kleidiAIPackedASize + FloatAlignment - 1) & ~(FloatAlignment - 1);

for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
auto* Data = const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(&DataParams[gemm_i]);
if (Data->QuantBZeroPoint == nullptr) {
continue;
}
// BZpCorr is at the end of packed B data (float-aligned)
if (kleidiAIPackedBSize > 0 && Data->PackedQuantBData != nullptr) {
Data->BZpCorr = reinterpret_cast<const float*>(Data->PackedQuantBData + alignedBOffset);
}
// AFloatBlkSum is at the end of the workspace's KleidiAI packed A (float-aligned)
if (kleidiAIPackedASize > 0 && Workspace != nullptr) {
std::byte* wsBase = reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
Data->AFloatBlkSum = reinterpret_cast<const float*>(wsBase + alignedAOffset);
}
}
}

if (ThreadPool == nullptr) {
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
const auto* Data = &DataParams[gemm_i];
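Taken together, the offset arithmetic in this hunk implies the following layouts, with each auxiliary float array living as a float-aligned tail of an existing buffer (sizes as stated in the comments above):

```text
Packed B (per weight tensor)            Per-GEMM workspace (per batch entry)
+-----------------------------+         +-----------------------------+
| KleidiAI packed B           |         | KleidiAI packed A           |
| (kleidiAIPackedBSize bytes) |         | (kleidiAIPackedASize bytes) |
+-----------------------------+         +-----------------------------+
| pad to alignof(float)       |         | pad to alignof(float)       |
+-----------------------------+         +-----------------------------+
| BZpCorr                     |         | AFloatBlkSum                |
| N * BlockCountK floats      |         | M * BlockCountK floats      |
+-----------------------------+         +-----------------------------+
```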
46 changes: 46 additions & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -529,6 +529,52 @@ struct MLAS_QNBIT_GEMM_DISPATCH {

QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr;

/**
* @brief Compute float-domain block sums of A for zero-point correction.
* Used when KleidiAI handles asymmetric quantization.
*
* @param A Supplies the float A matrix.
* @param CountM Number of rows of A.
* @param CountK Number of columns of A.
* @param BlkLen Number of values in a block.
* @param lda Leading dimension of A.
* @param[out] AFloatBlkSum Output: M * BlockCountK float sums.
*/
typedef void(ComputeAFloatBlkSum_Fn)(
const float* A,
size_t CountM,
size_t CountK,
size_t BlkLen,
size_t lda,
float* AFloatBlkSum
);

ComputeAFloatBlkSum_Fn* ComputeAFloatBlkSum = nullptr;

/**
* @brief Apply zero-point correction: C += ABlkSum * BCorr^T
* Used after KleidiAI GEMM for asymmetric quantization.
*
* @param ABlkSum Float block sums of A, [RangeCountM, BlockCountK] row-major.
* @param BCorr BZpCorrection, [RangeCountN, BlockCountK] row-major (pre-offset).
* @param[out] C Output matrix tile (pre-offset), accumulated.
* @param RangeCountM Number of M rows in this tile.
* @param RangeCountN Number of N columns in this tile.
* @param BlockCountK Number of blocks along K.
* @param ldc Leading dimension of C.
*/
typedef void(ApplyBZpCorrection_Fn)(
const float* ABlkSum,
const float* BCorr,
float* C,
size_t RangeCountM,
size_t RangeCountN,
size_t BlockCountK,
size_t ldc
);

ApplyBZpCorrection_Fn* ApplyBZpCorrection = nullptr;

/**
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
*
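For reference, a minimal scalar sketch of a `ComputeAFloatBlkSum_Fn` implementation matching the documented contract (row-major `A` with leading dimension `lda`, output of `CountM * BlockCountK` sums). This is an illustration only; the PR registers optimized kernels through the dispatch struct above.

```cpp
#include <algorithm>
#include <cstddef>

void ComputeAFloatBlkSum_Reference(
    const float* A, std::size_t CountM, std::size_t CountK,
    std::size_t BlkLen, std::size_t lda, float* AFloatBlkSum)
{
    const std::size_t BlockCountK = (CountK + BlkLen - 1) / BlkLen;  // MlasDivRoundup(CountK, BlkLen)
    for (std::size_t m = 0; m < CountM; ++m) {
        for (std::size_t blk = 0; blk < BlockCountK; ++blk) {
            const std::size_t kBegin = blk * BlkLen;
            const std::size_t kEnd = std::min(kBegin + BlkLen, CountK);
            float sum = 0.0f;
            for (std::size_t k = kBegin; k < kEnd; ++k) {
                sum += A[m * lda + k];  // float-domain sum of one K-block of row m
            }
            AFloatBlkSum[m * BlockCountK + blk] = sum;
        }
    }
}
```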