Implement multithreading in qgemm_kleidi #26301
```diff
@@ -11,10 +11,18 @@
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
 
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
 
 #include "mlasi_kleidiai.h"
 
+// Thread-local reusable buffers to reduce allocation overhead across tiles.
+struct KaiTlsBuffersQgemm {
+    std::vector<std::byte> lhs_packed;
+    std::vector<const std::byte*> lhs_base_table;
+};
+static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm;
+
 //Matmul with float output of dynamic quantized A and symmetric quantized B.
 
 size_t
```
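As an aside, the thread-local scratch-buffer pattern the hunk above introduces can be illustrated in isolation. The following is a minimal standalone sketch with hypothetical names (`TlsScratch`, `GetScratch`), not the PR code: each thread keeps its own buffer alive across calls, so repeated invocations with the same or smaller working-set size perform no new heap allocation.

```cpp
#include <cstddef>
#include <vector>

namespace {

// Hypothetical stand-in for the PR's KaiTlsBuffersQgemm.
struct TlsScratch {
    std::vector<std::byte> bytes;
};

thread_local TlsScratch g_scratch;

// Grows only when the cached capacity is too small; otherwise reuses the storage.
std::byte* GetScratch(std::size_t needed) {
    if (g_scratch.bytes.capacity() < needed) {
        g_scratch.bytes.reserve(needed);
    }
    g_scratch.bytes.resize(needed);
    return g_scratch.bytes.data();
}

}  // namespace

int main() {
    std::byte* p1 = GetScratch(1 << 16);  // first call allocates
    std::byte* p2 = GetScratch(1 << 12);  // smaller request reuses the same storage
    return (p1 == p2) ? 0 : 1;            // same backing pointer: no reallocation happened
}
```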
```diff
@@ -80,42 +88,148 @@ MLASCALL
 ArmKleidiAI::MlasDynamicQGemmBatch(
     const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
     const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
-    const size_t BatchN,
+    const size_t BatchSize,
     MLAS_THREADPOOL* ThreadPool
 ) {
-    for (auto b = BatchN; b > 0; --b,++DataParams) {
-        auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
-        auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
-        auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
+                              : kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
+    const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
+                              : kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
+    const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
+                              : kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
 
-        //TODO enable multi-threading for lhs packing and matmul
-        MLAS_UNREFERENCED_PARAMETER(ThreadPool);
+    size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
+                            : kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
+    size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
+                            : kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
 
-        //Dynamic Quantize A - lhs
-        auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
-        std::byte* lhs = nullptr;
-        std::unique_ptr<std::byte[]> fallback;
+    if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 ) {
+        return;
+    }
 
+    //We are required to enforce errors when we reach this stage as we will not be able
+    //to reverse the packing decision that was made for RHS.
+
+    ORT_ENFORCE(DataParams != nullptr, "Dynamic QGEMM requires valid DataParams.");
+    ORT_ENFORCE(Shape.K > 0, "Dynamic QGEMM requires Shape.K to be non-zero.");
+
+    for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) {
+        const auto& params = DataParams[batch_idx];
+        ORT_ENFORCE(params.A != nullptr, "Dynamic QGEMM requires non-null A pointer for batch ", batch_idx);
+        ORT_ENFORCE(params.C != nullptr, "Dynamic QGEMM requires non-null C pointer for batch ", batch_idx);
+        ORT_ENFORCE(params.PackedB != nullptr, "Dynamic QGEMM requires non-null PackedB pointer for batch ", batch_idx);
+        const size_t lda = params.lda != 0 ? params.lda : Shape.K;
+        const size_t ldc = params.ldc != 0 ? params.ldc : Shape.N;
+        ORT_ENFORCE(lda >= Shape.K, "lda (", lda, ") must be >= Shape.K (", Shape.K, ") for batch ", batch_idx);
+        ORT_ENFORCE(ldc >= Shape.N, "ldc (", ldc, ") must be >= Shape.N (", Shape.N, ") for batch ", batch_idx);
+    }
+
+    //Dynamic Quantize A - lhs
+    const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
+    std::byte* LhsPackedData = nullptr;
 
-        if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) {
-            lhs = static_cast<std::byte*>(DataParams->Workspace);
+    if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) {
+        g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
+    }
+    g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
+    LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();
 
+    //Per-batch table of lhs
+    if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) {
+        g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize);
+    }
+    g_kai_tls_qgemm.lhs_base_table.resize(BatchSize);
+    // Capture the shared batch table pointer so worker threads use the same backing storage.
+    const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data();
+    // B batches require no packing
+    // We have already decided the matmul variant we are using, before having values for M,N,K
+    MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
+        std::byte* lhs = nullptr;
+        if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) {
+            lhs = static_cast<std::byte*>(DataParams[batch_idx].Workspace);
         } else {
-            fallback = std::make_unique<std::byte[]>(lhs_size);
-            lhs = fallback.get();
+            lhs = &(LhsPackedData[LhsPackedStride * batch_idx]);
         }
 
         KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32"
             << " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0");
-        kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A,
-                                           Shape.K*sizeof(float), lhs);
-
-        KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa");
-        kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
-            Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB,
-            DataParams->C,
-            Shape.N * sizeof(float),
-            sizeof(float),
-            -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+        kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs);
+        tls_lhs_base[batch_idx] = lhs;
+    });
 
+    // tile iteration dimensions
+    std::array<size_t, 3> dim;
+    dim[0] = BatchSize; // B
```
Member: Is there room for optimization in any of the logic below if BatchSize == 1?

Author: I am not sure I understood this clearly. BatchSize, dim[0], only contributes a factor to the multiplication that produces RequiredTiles; in the later lines the calculation runs over the other dimensions, not BatchSize. I would argue the cost of multiplying the other two dims by 1 is negligible, and treating this case differently, e.g. checking for BatchSize == 1, would complicate the code for no substantive gain.

Member: Makes sense, thanks.
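To make the author's point concrete, here is a small standalone sketch (not part of the PR; the dim values are hypothetical) of the flat tile id being split into (B, M, N) indices, as the kernel loop further down does. When dim[0] == 1 the batch index is always 0, so keeping the batch dimension costs only a few integer divisions per tile.

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t dim[3] = {1, 4, 8};  // B, M-tiles, N-tiles (hypothetical values)
    for (std::size_t tid = 0; tid < dim[0] * dim[1] * dim[2]; ++tid) {
        // Decompose the flat iteration index into batch, M-tile, and N-tile indices.
        const std::size_t BIdx = tid / (dim[1] * dim[2]);
        const std::size_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
        const std::size_t NIdx = tid % dim[2];
        std::printf("tid=%zu -> B=%zu M=%zu N=%zu\n", tid, BIdx, MIdx, NIdx);
    }
    return 0;
}
```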
```diff
+    dim[1] = MlasDivRoundup(Shape.M, m_step); // M
+    dim[2] = MlasDivRoundup(Shape.N, n_step); // N
 
+    // Minimize the kernel call count for the number of available threads
+    auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
```
Member: Is there room for tuning this heuristic? What if dim[0] * dim[1] * dim[2] is closer to 2 * Thread_count? In that case, does it make sense to keep the required tiles as is, since the tail processing is quite small, or does it make sense to process bigger tiles and keep the tile count smaller?

Author: This line takes the minimum of dim[0] * dim[1] * dim[2] and the thread count in order to either keep the tile size as is or enlarge it accordingly. The later lines then update m_step/n_step by the scale computed in between; the rebalancing is there to minimise the number of kernel calls. For example, with an initial m_step/n_step of 16/64 and a 1x512 C matrix, the scaled m_step/n_step becomes 16/256 with 2 threads, 16/192 with 3 threads, and 16/128 with 6 threads.

Member: Fair enough, thanks.
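The rebalancing arithmetic can be checked directly. Below is a standalone sketch (not the PR code; `DivRoundup` stands in for `MlasDivRoundup`) that applies the same formulas as the diff to the author's example of a 1x512 C matrix with initial steps 16/64. It prints 16/256, 16/192 and 16/128 for 2, 3 and 6 threads respectively, matching the figures above.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Ceil-division, standing in for MlasDivRoundup.
static std::size_t DivRoundup(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

int main() {
    const std::size_t M = 1, N = 512;              // example C matrix from the discussion
    const std::size_t m_step0 = 16, n_step0 = 64;  // initial kernel step sizes
    const std::size_t dim0 = 1;                    // single batch

    for (std::size_t threads : {2, 3, 6}) {
        std::size_t dim1 = DivRoundup(M, m_step0);  // M tiles
        std::size_t dim2 = DivRoundup(N, n_step0);  // N tiles
        std::size_t required = std::min<std::size_t>(threads, dim0 * dim1 * dim2);

        // scale required tiles over available tile processors (same sequential updates as the PR)
        dim1 = DivRoundup(required * dim1, dim1 * dim2);
        dim2 = DivRoundup(required * dim2, dim1 * dim2);

        // compute new step sizes
        const std::size_t m_step = m_step0 * DivRoundup(DivRoundup(M, dim1), m_step0);
        const std::size_t n_step = n_step0 * DivRoundup(DivRoundup(N, dim2), n_step0);

        std::printf("threads=%zu -> m_step/n_step = %zu/%zu\n", threads, m_step, n_step);
    }
    return 0;
}
```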
```diff
+    // scale required tiles over available tile processors
+    dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
+    dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);
 
+    // compute new step sizes
+    m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step);
+    n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step);
 
+    // update tile iterations
+    dim[1] = MlasDivRoundup(Shape.M, m_step);
+    dim[2] = MlasDivRoundup(Shape.N, n_step);
 
+    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
+        // compute B,M,N index from iteration index
+        ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
+        ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
+        ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
 
+        // Get rhs tile, B
+        const size_t rhs_packed_offset =
+            UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K)
+                    : kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K);
 
+        const std::byte* B_base = reinterpret_cast<const std::byte*>(DataParams[BIdx].PackedB);
+        auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);
 
+        // Get lhs tile, A
+        const size_t lhs_packed_offset =
+            UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K)
+                    : kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K);
 
+        const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
+        auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);
 
+        auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? (Shape.M - MIdx * m_step) : m_step;
+        auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? (Shape.N - NIdx * n_step) : n_step;
 
+        float* dst_tile = reinterpret_cast<float*>(
+            reinterpret_cast<std::byte*>(DataParams[BIdx].C) +
+            MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) +
+            NIdx * n_step * sizeof(float)
         );
-    }
 
+        if (UseSME2) {
+            kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
+                TileSizeM, TileSizeN, Shape.K, ATile, BTile,
+                dst_tile,
+                DataParams[BIdx].ldc * sizeof(float),
+                sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        }
+        else {
+            kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
+                TileSizeM, TileSizeN, Shape.K, ATile, BTile,
+                dst_tile,
+                DataParams[BIdx].ldc * sizeof(float),
+                sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        }
+    });
 }
```
Member: Can't we just do the resizing directly instead of reserve + resize?

Author: Yes, reserve() + resize() and using only resize() both end up with one allocation plus one initialisation. But there is a very small performance difference when the allocation and the initialisation are separated rather than done at once with resize(). ("after" in the attached comparison is the case where the reserve() calls are removed and only resize() is used.)

[Screenshot: benchmark comparison of the reserve()+resize() and resize()-only variants]