Implement multithreading in qgemm_kleidi #26301
Conversation
@microsoft-github-policy-service agree company="Arm"
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
With this update of the KAI version from 1.10 to 1.15, can SME/SME2 detection be enabled on Windows too, to leverage the kernels?
https://github.com/microsoft/onnxruntime/pull/25187/files#r2223006773
https://github.com/microsoft/onnxruntime/pull/25760/files#r2325260570
Can we get the workflows run please?
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
  g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
}
g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
Can't we just do the resizing directly instead of reserve + resize ?
Yes, both reserve() + resize() and resize() alone end up with one allocation plus one initialisation. Somehow there is a very small performance difference between allocating and initialising separately versus doing both at once with resize() ("after" is the case with the reserve() calls removed and only resize() used).

  g_kai_tls_qgemm.output_tile.reserve(tile_elems);
}
// resize the tile to the required size (doesn't affect memory)
g_kai_tls_qgemm.output_tile.resize(tile_elems);
Ditto: is reserve() + resize() necessary?
Same as above.
// Thread-local reusable buffers to reduce allocation overhead across tiles.
struct KaiTlsBuffersQgemm {
  std::vector<float> output_tile;
  std::vector<float> bias_zero;
Is bias_zero used somewhere ?
addressed in the new commit
  g_kai_tls_qgemm.output_tile.resize(tile_elems);
}
float* temp_tile = g_kai_tls_qgemm.output_tile.data();
std::fill_n(temp_tile, TileSizeM * TileSizeN, 0.0f);
Is this buffer zeroing absolutely needed, i.e. does the micro-kernel accumulate into the existing contents?
Is there a way in the micro-kernel's interface to disregard the existing contents of the output buffer?
We can remove the fill_n; the kernel handles zeroing of the tile.
LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();

// Per-batch table of lhs
std::vector<const std::byte*> LhsBase(BatchSize);
Just a thought: can this vector containing the per-batch addresses be moved into the KaiTlsBuffersQgemm struct and be resized when its size is less than BatchSize?
The pro of that approach:
- We generally expect the BatchSize to be stable across runs, so we can do away with the dynamic-memory-allocation latency variance that comes with using std::vector.
The con of that approach:
- The size of that caching vector will be bound by the highest batch size that the kernel encounters.
Given that batch sizes are generally stable across runs, I am thinking the pro might outweigh the con. What are your thoughts on this?
This is an idea worth trying and measuring the impact. I implemented it; the results with a single thread are below.
After: LhsBase is moved inside the TLS structure. Before: LhsBase is a local buffer shared with the threads.
Here is the implementation:
// Per-batch table of lhs
if (g_kai_tls_qgemm.LhsBase.capacity() < BatchSize) {
  g_kai_tls_qgemm.LhsBase.reserve(BatchSize);
}
g_kai_tls_qgemm.LhsBase.resize(BatchSize);
// Capture the shared batch table pointer so worker threads use the same backing storage.
const std::byte** tls_lhs_base = g_kai_tls_qgemm.LhsBase.data();
// B batches require no packing
⋮
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda * sizeof(float), lhs);
tls_lhs_base[batch_idx] = lhs;
});
⋮
const std::byte* A_base = tls_lhs_base[BIdx];  // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);
I suspect perf-wise there isn't much difference; the suggestion comes from a performance-variance point of view. If we perform dynamic memory allocations on every Run(), we may see some latency variance. I was just wondering if this can be avoided, since in most cases the Gemm problem shapes stay the same across invocations. Let us dynamically resize only when we encounter a change of shape (batch size). Hope the motivation of the comment is clear now.
The motivation behind the comment is clear: if we expect generally stable batch sizes, reusing the vector's capacity across calls makes sense. If the performance results are also acceptable, we are all good with this idea. Please find the implementation in the latest commit.
if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) {
  lhs = static_cast<std::byte*>(DataParams->Workspace);
if (Shape.M == 0 || Shape.N == 0) {
Should there be a Shape.K check for completeness ?
addressed in the newest commit.
General sanity-check question: are there enough tests that trigger all the nuances of the multi-threaded implementation, i.e. tests with multiple batch sizes and M and N dimensions that exercise all aspects of it?
  return;
}
if ((Shape.M < m_step || Shape.N < n_step) && !DataParams->PackedB) {
  // Fallback to MLAS
There is no fallback implementation of MlasDynamicQGemmBatch().
onnxruntime/onnxruntime/core/mlas/lib/qgemm.cpp, lines 212 to 222 in 0f6cffc:
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
  // No fallback and putting in guards
  if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
    ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
  }
#endif
  MLAS_UNREFERENCED_PARAMETER(Shape);
  MLAS_UNREFERENCED_PARAMETER(DataParams);
  MLAS_UNREFERENCED_PARAMETER(BatchN);
  MLAS_UNREFERENCED_PARAMETER(ThreadPool);
If we get to this point, the computation should happen, or (maybe less preferably) it should be a hard error.
We will investigate the fallback case further and try to provide a better implementation.
Until then, we would like to get your opinion on using ORT_ENFORCE:
ORT_ENFORCE(false, "ArmKleidiAI::MlasDynamicQGemmBatch(): unsupported small-shape case (M < m_step or N < n_step)");
Could we instead implement @edgchen1's suggestion in the other PR, #26302 (comment): a universal check that can be used in all places to determine whether MLAS supports QGemm for a given problem shape, platform, etc.?
Also, since we have a check on the M dimension, this might need some thought: in the current setup we turn off MLAS usage for QGemm in PrePack() if we don't detect SME or if the weight's shape doesn't match the requirements in PrePack(). See here and here. The M dimension won't be known in PrePack().
Just curious: what would happen if M was < m_step? Would there be a crash, or would the perf be sub-optimal? If so, we need to add a runtime check in the CPU kernel's Run() function, which means we may need to perform pre-packing for both KAI and the "regular" path. See here.
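A rough sketch of what such a universal check might look like; the function name and the exact conditions are assumptions for illustration, not the helper that was actually added:

```cpp
#include <cstddef>

// Hypothetical helper: returns true when the KleidiAI dynamic QGemm path can
// handle the given problem shape on the current CPU. A real version would live
// in MLAS and consult the m_step/n_step of the selected micro-kernel.
bool DynamicQGemmShapeSupported(std::size_t M, std::size_t N, std::size_t K,
                                bool has_sme, std::size_t m_step, std::size_t n_step) {
  if (!has_sme) return false;                    // kernels require SME
  if (M == 0 || N == 0 || K == 0) return false;  // degenerate shapes
  // Small-shape case from the discussion above: with no fallback, reject shapes
  // the kernel cannot tile (assumption; the PR may handle this differently).
  if (M < m_step || N < n_step) return false;
  return true;
}
```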
// Final output tile pointer
float* dst_tile = reinterpret_cast<float*>(CTile);
std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
What's the benefit of writing to a temporary buffer (temp_tile) and then copying it to dst_tile, instead of writing directly to dst_tile?
The idea behind it was to keep the arithmetic on a temporary tile, making it less error-prone, as was done in the sgemms. But I see that doing the calculations on the destination and writing directly lowers the complexity: instead of having the result in each TLS buffer and copying it to the destination tile, the destination tile can hold the result directly.
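A small self-contained sketch of the two approaches; MicroKernel is a toy stand-in for the KleidiAI matmul micro-kernel, assuming (as discussed above) it can write into a strided destination:

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Toy micro-kernel: fills an M x N tile of dst, whose rows are ldc elements apart.
void MicroKernel(std::size_t M, std::size_t N, float* dst, std::size_t ldc) {
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j)
      dst[i * ldc + j] = 1.0f;  // placeholder for the real accumulation
}

void WriteTile(float* C, std::size_t ldc, std::size_t TileSizeM, std::size_t TileSizeN) {
  // Before: accumulate into a dense thread-local temporary, then copy into C.
  std::vector<float> temp(TileSizeM * TileSizeN);
  MicroKernel(TileSizeM, TileSizeN, temp.data(), TileSizeN);
  for (std::size_t i = 0; i < TileSizeM; ++i)
    std::memcpy(C + i * ldc, temp.data() + i * TileSizeN, TileSizeN * sizeof(float));

  // After: let the kernel write straight into C using its real row stride, no copy.
  MicroKernel(TileSizeM, TileSizeN, C, ldc);
}
```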
Measuring the impact with a single thread:
Will trigger CI once you push commits addressing the PR feedback (right now I only see a rebase). Thanks.
We checked the existing tests for qgemm. In the current implementation the tests only cover the ThreadPool == null case. We created a follow-up ticket for test coverage.
If all the tests are with ThreadPool == null, does that mean the new threadpool-based parallel code path(s) are not exercised?
It means it was not exercised by the onnxruntime_mlas_test run, but it is by onnxruntime_perf_test. However, unit tests for the multithreaded code have now been added in the latest commit, so both cases can use multiple threads.
Signed-off-by: melkap01 <[email protected]>
Unused variable removed, unnecessary temp_tile use and copy removed, K==0 case checked. Signed-off-by: melkap01 <[email protected]>
Signed-off-by: melkap01 <[email protected]>
Signed-off-by: melkap01 <[email protected]>
// Indicates that the biases are a constant input and thus already quantized / packed
bool dynamic_quant_mlas_bias_data_was_packed_{false};
#endif
// Flag storage is handled by MatMulIntegerBase.
The comment on line 200 is dangling ("Indicates when...."); I guess it is no longer relevant given that the flags have moved.
will be addressed in the new commit
done
dim[2] = MlasDivRoundup(Shape.N, n_step);  // N

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
Is there room for tuning this heuristic? What if dim[0] * dim[1] * dim[2] is closer to 2 * thread_count? In that case, does it make sense to keep the required tiles as is, since the tail processing is quite small, or does it make sense to process bigger tiles and keep the tile count smaller?
In this line, the minimum of dim[0] * dim[1] * dim[2] and the thread count is taken, in order to keep the tile size as is or enlarge it accordingly. The later lines then update m_step/n_step by the scale calculated in between, so there is a rebalancing step that minimises the number of kernel calls. For example, m_step/n_step = 16/64 initially; after scaling according to the required tiles, the new m_step/n_step is 16/256 for a 1x512 C matrix when #threads = 2, 16/192 when #threads = 3, and 16/128 when #threads = 6.
I believe this logic both minimises the kernel calls and leaves no room for tail processing. Sorry if I didn't understand the question correctly, but I think this logic reduces tail processing because it tries to fit the tiles into the C tensor cleanly. Please highlight any point this does not answer.
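To make the rebalancing concrete, here is a small sketch of my reading of the logic described above (not the exact PR code); it reproduces the quoted numbers: for a 1x512 C matrix with an initial n_step of 64, the effective n_step becomes 256, 192, and 128 for 2, 3, and 6 threads respectively.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

std::size_t DivRoundup(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// Enlarge n_step so the number of N tiles roughly matches the thread budget,
// minimising kernel calls while keeping tiles aligned to the kernel's n_step.
std::size_t RebalancedNStep(std::size_t M, std::size_t N, std::size_t m_step,
                            std::size_t n_step, std::size_t max_threads, std::size_t batch) {
  const std::size_t tiles = batch * DivRoundup(M, m_step) * DivRoundup(N, n_step);
  const std::size_t required_tiles = std::min(max_threads, tiles);
  const std::size_t scale = DivRoundup(DivRoundup(N, n_step), required_tiles);
  return n_step * scale;
}

int main() {
  for (std::size_t threads : {2, 3, 6})
    std::printf("threads=%zu -> n_step=%zu\n", threads, RebalancedNStep(1, 512, 16, 64, threads, 1));
}
```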
Fair enough, thanks
// tile iteration dimensions
std::array<size_t, 3> dim;
dim[0] = BatchSize;  // B
Is there room for optimization in any of the logic below if BatchSize == 1 ?
I am not sure I understood this correctly. BatchSize (dim[0]) only contributes to the multiplication for RequiredTiles; in the later lines all the calculation goes over the other dimensions, not BatchSize. I would argue the cost of multiplying the other two dims by 1 is negligible, and changing the code to treat this case differently (e.g. checking for BatchSize == 1) would complicate the code for no substantive gain.
Makes sense, thanks
is_packed = false;

// only pack Matrix B
// only pack Matrix B++
Nit: Is ++ a typo ?
Yes it is, will be addressed in the new commit.
done
const size_t packed_b_size = MlasDynamicQgemmPackBSize(ctx.N, ctx.K);
if (packed_b_size == 0) {
  can_use_dynamic_quant_mlas_ = false;
  return true;
Should this return false if can_use_dynamic_quant_mlas_ = false ?
will be addressed in the new commit
corrected
bool IsBShapeSupportedForDynamicQuant(const TensorShape& tensor_shape) {
  b_shape_ = tensor_shape;
  if (b_shape_.NumDimensions() < 2) {
    return false;
Low-priority question: can 1-D shapes be promoted to 2-D shapes by prepending or appending a 1?
It is implemented in the latest commit.
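For illustration, a minimal sketch of such a promotion; this follows numpy/ONNX MatMul conventions (a 1-D B is treated as a column by appending a 1, a 1-D A as a row by prepending one) and is an assumption about the approach, not the exact code in the commit.

```cpp
#include <cstdint>
#include <vector>

// Promote a 1-D shape to 2-D. For the B input a 1 is appended ({K} -> {K, 1});
// for an A input a 1 would be prepended instead ({K} -> {1, K}).
std::vector<int64_t> PromoteTo2D(std::vector<int64_t> shape, bool is_b_input) {
  if (shape.size() == 1) {
    if (is_b_input)
      shape.push_back(1);
    else
      shape.insert(shape.begin(), 1);
  }
  return shape;
}
```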
std::optional<Tensor> transposed_buffer;
};

bool TryKleidiaiDynamicPrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
A brief description of what each of the following helper methods does, what it looks for, and when it returns true/false would help the reader.
provided
public:
  void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
    // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
    if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME())
Could we use the MLAS API to check whether dynamic QGemm functionality is available?
done
public:
  void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
    // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
    if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME())
Same as above
done
MlasIsDynamicQGemmAvailable() used instead of CPUIDInfo::GetCPUIDInfo().HasArm_SME(). Signed-off-by: melkap01 <[email protected]>
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
Signed-off-by: melkap01 <[email protected]>
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
Signed-off-by: melkap01 <[email protected]>
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).


Key changes
This PR improves the performance of dynamic QGemms by implementing tiling and threading across operations. The changes introduce thread-local buffers for reusing memory during inference and use them in the dynamic quantised MatMul operations that run on KleidiAI kernels. The KleidiAI version is also updated to 1.15.0.
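As a rough illustration of how the two ideas fit together (thread-local reusable buffers plus a tiled parallel dispatch), here is a simplified, self-contained sketch with invented names; the actual PR dispatches tiles through the ORT/MLAS thread pool onto KleidiAI micro-kernels.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Thread-local scratch reused across tiles to avoid per-tile allocations.
struct TlsBuffers {
  std::vector<float> output_tile;
};
thread_local TlsBuffers g_tls;

// Simplified parallel-for over tiles; the real code uses the MLAS thread pool.
void ParallelForTiles(std::size_t tile_count, const std::function<void(std::size_t)>& work) {
  const std::size_t hw = std::max<std::size_t>(1, std::thread::hardware_concurrency());
  const std::size_t workers = std::min(hw, tile_count);
  std::vector<std::thread> threads;
  for (std::size_t w = 0; w < workers; ++w)
    threads.emplace_back([=] {
      for (std::size_t t = w; t < tile_count; t += workers) work(t);
    });
  for (auto& t : threads) t.join();
}

void RunTiledQGemm(std::size_t tiles_m, std::size_t tiles_n, std::size_t tile_elems) {
  ParallelForTiles(tiles_m * tiles_n, [&](std::size_t /*tile_idx*/) {
    // Reuses this thread's capacity after the first call; no allocation when the shape is stable.
    g_tls.output_tile.resize(tile_elems);
    // ... run the quantised micro-kernel for this (m, n) tile into g_tls.output_tile ...
  });
}
```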
Example performance
single thread: [benchmark chart]
2 threads: [benchmark chart]