Commits (25)
517e166
Implement multithreading in qgemm_kleidi
melkap01-Arm Oct 14, 2025
bd05fce
fixes addressed:
melkap01-Arm Oct 30, 2025
75fee7a
lhs_base_table buffer implemented inside TLS
melkap01-Arm Oct 31, 2025
e53e67b
multithreaded qgemms coverage with single-multi threaded
melkap01-Arm Nov 3, 2025
c2428fb
Test commit damdoo01
damdoo01-arm Nov 10, 2025
0479aea
Undo Test commit damdoo01
damdoo01-arm Nov 10, 2025
9ef3b4c
SME2 test case check brought on the Test() function, after rebase
melkap01-Arm Nov 28, 2025
32bf43c
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Dec 4, 2025
b6ff3be
Dynamic Qgemm Prepack() refactored
melkap01-Arm Dec 4, 2025
7ed0e5c
Merge branch 'main' into melkap01_implement_mt_qgemm
melkap01-Arm Dec 5, 2025
76ba64f
Quant Kernel log added, include corrected in dynamic qgemm test
melkap01-Arm Dec 5, 2025
6bac9a5
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Dec 10, 2025
f1605e5
provider test cases for keldiai dynamic qgemms added
melkap01-Arm Dec 10, 2025
0c3748b
-Arm KleidiAI helper methods in Mlas space commented.
melkap01-Arm Dec 15, 2025
99fe8c5
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Dec 18, 2025
47e4c92
KleidiAI dynamic quantization supported by promoting 1D B tensor to 2D
melkap01-Arm Dec 19, 2025
fb8eefb
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Dec 22, 2025
d9a26bf
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 2, 2026
cd80e56
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 6, 2026
6356e68
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 6, 2026
50dddaf
lintrunner issue fixed
melkap01-Arm Jan 6, 2026
017a425
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 6, 2026
2ad388c
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 7, 2026
8dc8bc3
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 8, 2026
e000f04
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
melkap01-Arm Jan 9, 2026
151 changes: 17 additions & 134 deletions onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
@@ -164,132 +164,23 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
Status Compute(OpKernelContext* context) const override;

#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override {
// only pack Matrix B
if (input_idx == GetBIdx()) {
const Tensor* b_zp_constant_tensor{nullptr};
bool b_quantization_might_be_asymmetric = false;

const OrtValue* b_zp;
if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) {
b_zp_constant_tensor = &b_zp->Get<Tensor>();
}

// MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros
// or not provided.
if (b_zp_constant_tensor != nullptr) {
// B zero point is constant. Check if it is all zeros.
assert(b_zp_constant_tensor->IsDataType<uint8_t>() || b_zp_constant_tensor->IsDataType<int8_t>());
const auto* zp_bytes = static_cast<const std::byte*>(b_zp_constant_tensor->DataRaw());
const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes();
b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes,
[](std::byte v) { return v != std::byte{0}; });
} else {
// B zero point input is not constant. If it exists, we can't assume symmetric quantization.
const auto input_defs = Info().node().InputDefs();
const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists();
b_quantization_might_be_asymmetric = b_zp_input_exists;
}

// MlasDynamicQgemm requires scale data to be available at packing stage
const Tensor* b_scale_tensor = nullptr;
const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor);

can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available);

// Kleidi dynamic path requires strictly positive, finite scales.
// Disable if any invalid scale is detected.
if (can_use_dynamic_quant_mlas_) {
const auto bs = b_scale_tensor->DataAsSpan<float>();
const bool has_invalid =
std::any_of(bs.begin(), bs.end(),
[](float s) { return !std::isfinite(s) || s <= 0.0f; });

if (has_invalid) {
can_use_dynamic_quant_mlas_ = false;
}
}

if (!MlasIsDynamicQGemmAvailable()) {
can_use_dynamic_quant_mlas_ = false;
}

// Only handle the common case of a 2D weight matrix. Additional matrices
// could be handled by stacking the packed buffers.
b_shape_ = tensor.Shape();
if (b_shape_.NumDimensions() >= 2) {
for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) {
if (b_shape_[i] != 1) {
can_use_dynamic_quant_mlas_ = false;
break;
}
}
} else {
can_use_dynamic_quant_mlas_ = false;
}

// Can we use the mlas dynamic Q gemm interface supported with float output ?
if (!can_use_dynamic_quant_mlas_) {
// default to piece wise mlas interface with separate int matmul, quantize and float conversion
return MatMulIntegerToFloatBase::PrePack(tensor, input_idx, alloc, is_packed, prepacked_weights);
}
is_packed = false;

// Default to all zeros for bias
const Tensor* bias_tensor{nullptr};
const OrtValue* bias;
if (Info().TryGetConstantInput(IN_BIAS, &bias)) {
bias_tensor = &bias->Get<Tensor>();
dynamic_quant_mlas_bias_data_was_packed_ = true;
}
size_t K = static_cast<size_t>(b_shape_[0]);
size_t N = static_cast<size_t>(b_shape_[1]);

const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw());

std::optional<Tensor> b_trans_buffer;
if (IsBTransposed()) {
std::swap(K, N);
b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K);
}
bool SupportsKleidiaiDynamicQuant() const override {
if (!MlasIsDynamicQGemmAvailable()) {
return false;
}
return true;
}

const size_t packed_b_size = MlasDynamicQgemmPackBSize(N, K);
if (packed_b_size == 0) {
return Status::OK();
}
int GetBScaleIdx() const override {
return IN_B_SCALE;
}

packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size, true);
// Initialize memory to 0 as there could be some padding associated with pre-packed
// buffer memory and we do not want it uninitialized and generate different hashes
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_b_.get(), 0, packed_b_size);

const auto scales = static_cast<size_t>(b_scale_tensor->Shape().Size()) == N ? std::vector<float>(&b_scale_tensor->Data<float>()[0],
&b_scale_tensor->Data<float>()[N])
:
// Broadcast matrix scale to all channels
std::vector<float>(N, b_scale_tensor->Data<float>()[0]);

const auto biases = bias_tensor != nullptr ? std::vector<float>(&bias_tensor->Data<float>()[0],
&bias_tensor->Data<float>()[N])
:
// Broadcast zero to all channels - no bias data is available
std::vector<float>(N, 0.f);

MlasDynamicQgemmPackB(N, K, reinterpret_cast<const int8_t*>(b_data), scales.data(), biases.data(),
packed_b_.get());

bool share_prepacked_weights = (prepacked_weights != nullptr);
if (share_prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size);
}
int GetBZeroPointIdx() const override {
return IN_B_ZERO_POINT;
}

is_packed = true;
}
return Status::OK();
int GetBiasIdx() const override {
return IN_BIAS;
}
#endif

@@ -303,14 +194,6 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {

protected:
int GetBIdx() const override { return IN_B; }

private:
// Indicates when MlasDynamicQGemmBatch() can be used
bool can_use_dynamic_quant_mlas_{false};
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
// Indicates that the biases are a constant input and thus already quantized / packed
bool dynamic_quant_mlas_bias_data_was_packed_{false};
#endif
};

class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase {
@@ -381,7 +264,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
}
}
// Guard against KleidiAI functions being called in non kleidi builds
// TODO: migrate to a suitable override function call for kleidi dynamic qgemm function calls
// TODO: migrate to a suitable override function call for KleidiAI dynamic qgemm function calls
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
else {
MatMulComputeHelper helper;
@@ -390,10 +273,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
// deleted during session init post prepacking
nullptr,
nullptr));

// allocate the kernel’s output tensor from the execution context
Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());

// Bail out early if the output is going to be empty
// Bail out early if any dimension is 0, since the total number of output elements is then 0
if (y->Shape().Size() == 0)
return Status::OK();

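The net effect of this hunk is that the kernel-local PrePack() is replaced by a handful of small overrides (SupportsKleidiaiDynamicQuant(), GetBScaleIdx(), GetBZeroPointIdx(), GetBiasIdx()) that a shared base-class PrePack() can query. The refactored base-class code is not part of this hunk, so the sketch below is only an assumption of how such overrides could be consumed; the class name, member names, and control flow are illustrative, not the actual ONNX Runtime implementation.

```cpp
// Illustrative-only sketch (not the real MatMulIntegerToFloatBase):
// shows how a base-class PrePack might branch on the overrides added above.
class MatMulIntegerToFloatBaseSketch {
 protected:
  virtual int GetBIdx() const = 0;
  virtual bool SupportsKleidiaiDynamicQuant() const { return false; }
  virtual int GetBScaleIdx() const { return -1; }
  virtual int GetBZeroPointIdx() const { return -1; }
  virtual int GetBiasIdx() const { return -1; }

  // Hypothetical helper: decide whether the KleidiAI dynamic-quant packing
  // path applies to this input before falling back to the generic path.
  bool WantsDynamicQuantPack(int input_idx) const {
    if (input_idx != GetBIdx()) {
      return false;  // only the B weight matrix is ever pre-packed
    }
    if (!SupportsKleidiaiDynamicQuant()) {
      return false;  // derived kernel reports the fast path is unavailable
    }
    // A real implementation would next fetch the constant inputs at
    // GetBScaleIdx() / GetBZeroPointIdx() / GetBiasIdx(), validate symmetric
    // quantization and positive finite scales, and call MlasDynamicQgemmPackB().
    return true;
  }
};
```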
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -53,6 +53,7 @@ namespace ArmKleidiAI {

// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
inline const bool UseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();

// Buffer packing routines.
//
168 changes: 141 additions & 27 deletions onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
@@ -11,10 +11,18 @@
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"

#include "mlasi_kleidiai.h"

// Thread-local reusable buffers to reduce allocation overhead across tiles.
struct KaiTlsBuffersQgemm {
std::vector<std::byte> lhs_packed;
std::vector<const std::byte*> lhs_base_table;
};
static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm;

//Matmul with float output of dynamic quantized A and symmetric quantized B.

size_t
@@ -80,42 +88,148 @@ MLASCALL
ArmKleidiAI::MlasDynamicQGemmBatch(
const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
const size_t BatchN,
const size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
) {
for (auto b = BatchN; b > 0; --b,++DataParams) {
auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();

const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();

//TODO enable multi-threading for lhs packing and matmul
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();

//Dynamic Quantize A - lhs
auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
std::byte* lhs = nullptr;
std::unique_ptr<std::byte[]> fallback;
if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 ) {
return;
}

//We are required to enforce errors when we reach this stage as we will not be able
//to reverse the packing decision that was made for RHS.

ORT_ENFORCE(DataParams != nullptr, "Dynamic QGEMM requires valid DataParams.");
ORT_ENFORCE(Shape.K > 0, "Dynamic QGEMM requires Shape.K to be non-zero.");

for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) {
const auto& params = DataParams[batch_idx];
ORT_ENFORCE(params.A != nullptr, "Dynamic QGEMM requires non-null A pointer for batch ", batch_idx);
ORT_ENFORCE(params.C != nullptr, "Dynamic QGEMM requires non-null C pointer for batch ", batch_idx);
ORT_ENFORCE(params.PackedB != nullptr, "Dynamic QGEMM requires non-null PackedB pointer for batch ", batch_idx);
const size_t lda = params.lda != 0 ? params.lda : Shape.K;
const size_t ldc = params.ldc != 0 ? params.ldc : Shape.N;
ORT_ENFORCE(lda >= Shape.K, "lda (", lda, ") must be >= Shape.K (", Shape.K, ") for batch ", batch_idx);
ORT_ENFORCE(ldc >= Shape.N, "ldc (", ldc, ") must be >= Shape.N (", Shape.N, ") for batch ", batch_idx);
}

//Dynamic Quantize A - lhs
const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
std::byte* LhsPackedData = nullptr;

if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) {
lhs = static_cast<std::byte*>(DataParams->Workspace);
if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) {

g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
}
g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
Reviewer (Member): Can't we just do the resizing directly instead of reserve + resize?

Author (melkap01-Arm): Yes, reserve() + resize() and resize() alone both end up with one allocation plus one initialisation. There is only a very small performance difference between separating the allocation and initialisation and doing both at once with resize() ("after" in the attached comparison is the case where the reserve() calls are removed and only resize() is used).
(attached benchmark: ort_ops_compare_2_thread_before_2025-10-29_13-08-56_vs_2_thread_after_2025-10-29_13-32-05)
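For reference, a minimal standalone sketch of the two growth strategies discussed in this thread; the buffer and function names are illustrative, not the actual TLS members. Both variants end up doing at most one allocation plus one value-initialisation of the new bytes.

```cpp
#include <cstddef>
#include <vector>

// Variant used in the PR: explicit capacity check, then reserve() + resize().
void GrowReserveThenResize(std::vector<std::byte>& buf, std::size_t needed) {
    if (buf.capacity() < needed) {
        buf.reserve(needed);   // one allocation (no value-initialisation yet)
    }
    buf.resize(needed);        // value-initialises the newly added bytes
}

// Suggested alternative: resize() alone allocates (if capacity is too small)
// and value-initialises the new bytes in a single call.
void GrowResizeOnly(std::vector<std::byte>& buf, std::size_t needed) {
    buf.resize(needed);
}
```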

LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();

//Per-batch table of lhs
if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) {

g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize);
}
g_kai_tls_qgemm.lhs_base_table.resize(BatchSize);
// Capture the shared batch table pointer so worker threads use the same backing storage.
const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data();
// B batches require no packing
// We have already decided the matmul variant we are using, before having values for M,N,K
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {

std::byte* lhs = nullptr;
if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) {
lhs = static_cast<std::byte*>(DataParams[batch_idx].Workspace);
} else {
fallback = std::make_unique<std::byte[]>(lhs_size);
lhs = fallback.get();
lhs = &(LhsPackedData[LhsPackedStride * batch_idx]);
}

KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32"
<< " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0");
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A,
Shape.K*sizeof(float), lhs);

KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa");
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB,
DataParams->C,
Shape.N * sizeof(float),
sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs);
tls_lhs_base[batch_idx] = lhs;
});

// tile iteration dimensions
std::array<size_t, 3> dim;
dim[0] = BatchSize; // B
Reviewer (Member): Is there room for optimization in any of the logic below if BatchSize == 1?

Author (melkap01-Arm): I am not sure I understood this clearly. BatchSize (dim[0]) only contributes as a factor in RequiredTiles; the later calculations iterate over the other dimensions, not BatchSize. I would argue that the cost of multiplying the other two dims by 1 is negligible, and special-casing BatchSize == 1 would complicate the code for no substantive gain.

Reviewer (Member): Makes sense, thanks.

dim[1] = MlasDivRoundup(Shape.M, m_step); // M
dim[2] = MlasDivRoundup(Shape.N, n_step); // N

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
Reviewer (Member): Is there room for tuning this heuristic? What if dim[0] * dim[1] * dim[2] is closer to 2 * thread_count? In that case, does it make sense to keep the required tiles as is, since the tail processing is quite small, or to process bigger tiles and keep the tile count smaller?

Author (melkap01-Arm, Dec 11, 2025): This line takes the minimum of dim[0] * dim[1] * dim[2] and the thread count, so the tile size is either kept as is or enlarged accordingly. The later lines then update m_step/n_step by the scale computed in between; this rebalancing minimises the number of kernel calls. For example, starting from m_step/n_step = 16/64, the rescaled steps for a 1x512 C matrix become 16/256 with 2 threads, 16/192 with 3 threads, and 16/128 with 6 threads. I believe this logic both minimises kernel calls and leaves little room for tail processing, because it tries to fit the tiles into the C tensor cleanly. Please highlight any point if this does not answer your question.

Reviewer (Member): Fair enough, thanks.
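To make the rebalancing concrete, here is a small standalone sketch of the step-rescaling arithmetic for the 1x512 example above; DivRoundup is a local ceiling-division stand-in for MlasDivRoundup, and the shapes and thread counts are just the ones quoted in the reply.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

static std::size_t DivRoundup(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

int main() {
    const std::size_t M = 1, N = 512, BatchSize = 1;
    const std::size_t thread_counts[] = {2, 3, 6};
    for (std::size_t threads : thread_counts) {
        std::size_t m_step = 16, n_step = 64;        // initial kernel step sizes
        std::size_t dim0 = BatchSize;
        std::size_t dim1 = DivRoundup(M, m_step);    // M tiles
        std::size_t dim2 = DivRoundup(N, n_step);    // N tiles
        const std::size_t RequiredTiles = std::min(threads, dim0 * dim1 * dim2);

        // scale required tiles over available tile processors
        dim1 = DivRoundup(RequiredTiles * dim1, dim1 * dim2);
        dim2 = DivRoundup(RequiredTiles * dim2, dim1 * dim2);

        // compute new step sizes
        m_step *= DivRoundup(DivRoundup(M, dim1), m_step);
        n_step *= DivRoundup(DivRoundup(N, dim2), n_step);

        // Prints 16/256 for 2 threads, 16/192 for 3 threads, 16/128 for 6 threads.
        std::printf("threads=%zu -> m_step=%zu n_step=%zu\n", threads, m_step, n_step);
    }
    return 0;
}
```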


// scale required tiles over available tile processors
dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

// compute new step sizes
m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step);
n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step);

// update tile iterations
dim[1] = MlasDivRoundup(Shape.M, m_step);
dim[2] = MlasDivRoundup(Shape.N, n_step);

MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {

// compute B,M,N index from iteration index
ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

// Get rhs tile, B
const size_t rhs_packed_offset =
UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K)
: kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K);

const std::byte* B_base = reinterpret_cast<const std::byte*>(DataParams[BIdx].PackedB);
auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);

// Get lhs tile, A
const size_t lhs_packed_offset =
UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K)
: kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K);

const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);

auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? (Shape.M - MIdx * m_step) : m_step;
auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? (Shape.N - NIdx * n_step) : n_step;

float* dst_tile = reinterpret_cast<float*>(
reinterpret_cast<std::byte*>(DataParams[BIdx].C) +
MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) +
NIdx * n_step * sizeof(float)
);
}

if (UseSME2) {
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
dst_tile,
DataParams[BIdx].ldc * sizeof(float),
sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
}
else {
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
dst_tile,
DataParams[BIdx].ldc * sizeof(float),
sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
}
});
}
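As a sanity check on the tile-index arithmetic in the parallel loop above, here is a tiny standalone example that decomposes a linear tile id into (B, M, N) indices; the tile counts are made up for illustration only.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    // Assumed illustrative tile counts: 2 batches, 3 M tiles, 4 N tiles.
    const std::array<std::size_t, 3> dim{2, 3, 4};
    const std::size_t tiles_per_batch = dim[1] * dim[2];
    for (std::size_t tid = 0; tid < dim[0] * tiles_per_batch; ++tid) {
        const std::size_t BIdx = tid / tiles_per_batch;             // which batch
        const std::size_t MIdx = (tid % tiles_per_batch) / dim[2];  // which M tile
        const std::size_t NIdx = (tid % tiles_per_batch) % dim[2];  // which N tile
        std::printf("tid=%2zu -> B=%zu M=%zu N=%zu\n", tid, BIdx, MIdx, NIdx);
    }
    return 0;
}
```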
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/qgemm.cpp
@@ -206,7 +206,7 @@
MlasIsDynamicQGemmAvailable()
{
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
return ArmKleidiAI::UseSME2;
return (ArmKleidiAI::UseSME || ArmKleidiAI::UseSME2);
#else
return false;
#endif