diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 2bba0adcd987c..b929fc2de4bbd 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -164,132 +164,23 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { Status Compute(OpKernelContext* context) const override; #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - /*out*/ bool& is_packed, - /*out*/ PrePackedWeights* prepacked_weights) override { - // only pack Matrix B - if (input_idx == GetBIdx()) { - const Tensor* b_zp_constant_tensor{nullptr}; - bool b_quantization_might_be_asymmetric = false; - - const OrtValue* b_zp; - if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { - b_zp_constant_tensor = &b_zp->Get(); - } - - // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros - // or not provided. - if (b_zp_constant_tensor != nullptr) { - // B zero point is constant. Check if it is all zeros. - assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); - const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); - const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); - b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, - [](std::byte v) { return v != std::byte{0}; }); - } else { - // B zero point input is not constant. If it exists, we can't assume symmetric quantization. - const auto input_defs = Info().node().InputDefs(); - const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); - b_quantization_might_be_asymmetric = b_zp_input_exists; - } - - // MlasDynamicQgemm requires scale data to be available at packing stage - const Tensor* b_scale_tensor = nullptr; - const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); - - can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); - - // Kleidi dynamic path requires strictly positive, finite scales. - // Disable if any invalid scale is detected. - if (can_use_dynamic_quant_mlas_) { - const auto bs = b_scale_tensor->DataAsSpan(); - const bool has_invalid = - std::any_of(bs.begin(), bs.end(), - [](float s) { return !std::isfinite(s) || s <= 0.0f; }); - - if (has_invalid) { - can_use_dynamic_quant_mlas_ = false; - } - } - - if (!MlasIsDynamicQGemmAvailable()) { - can_use_dynamic_quant_mlas_ = false; - } - - // Only handle the common case of a 2D weight matrix. Additional matrices - // could be handled by stacking the packed buffers. - b_shape_ = tensor.Shape(); - if (b_shape_.NumDimensions() >= 2) { - for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) { - if (b_shape_[i] != 1) { - can_use_dynamic_quant_mlas_ = false; - break; - } - } - } else { - can_use_dynamic_quant_mlas_ = false; - } - - // Can we use the mlas dynamic Q gemm interface supported with float output ? 
- if (!can_use_dynamic_quant_mlas_) { - // default to piece wise mlas interface with separate int matmul, quantize and float conversion - return MatMulIntegerToFloatBase::PrePack(tensor, input_idx, alloc, is_packed, prepacked_weights); - } - is_packed = false; - - // Default to all zeros for bias - const Tensor* bias_tensor{nullptr}; - const OrtValue* bias; - if (Info().TryGetConstantInput(IN_BIAS, &bias)) { - bias_tensor = &bias->Get<Tensor>(); - dynamic_quant_mlas_bias_data_was_packed_ = true; - } - size_t K = static_cast<size_t>(b_shape_[0]); - size_t N = static_cast<size_t>(b_shape_[1]); - - const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw()); - - std::optional b_trans_buffer; - if (IsBTransposed()) { - std::swap(K, N); - b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K); - } + bool SupportsKleidiaiDynamicQuant() const override { + if (!MlasIsDynamicQGemmAvailable()) { + return false; + } + return true; + } - const size_t packed_b_size = MlasDynamicQgemmPackBSize(N, K); - if (packed_b_size == 0) { - return Status::OK(); - } + int GetBScaleIdx() const override { + return IN_B_SCALE; + } - packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size, true); - // Initialize memory to 0 as there could be some padding associated with pre-packed - // buffer memory and we do not want it uninitialized and generate different hashes - // if and when we try to cache this pre-packed buffer for sharing between sessions. - memset(packed_b_.get(), 0, packed_b_size); - - const auto scales = static_cast<size_t>(b_scale_tensor->Shape().Size()) == N ? std::vector<float>(&b_scale_tensor->Data<float>()[0], - &b_scale_tensor->Data<float>()[N]) - : - // Broadcast matrix scale to all channels - std::vector<float>(N, b_scale_tensor->Data<float>()[0]); - - const auto biases = bias_tensor != nullptr ? std::vector<float>(&bias_tensor->Data<float>()[0], - &bias_tensor->Data<float>()[N]) - : - // Broadcast zero to all channels - no bias data is available - std::vector<float>(N, 0.f); - - MlasDynamicQgemmPackB(N, K, reinterpret_cast<const int8_t*>(b_data), scales.data(), biases.data(), - packed_b_.get()); - - bool share_prepacked_weights = (prepacked_weights != nullptr); - if (share_prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size); - } + int GetBZeroPointIdx() const override { + return IN_B_ZERO_POINT; + } - is_packed = true; - } - return Status::OK(); + int GetBiasIdx() const override { + return IN_BIAS; } #endif @@ -303,14 +194,6 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { protected: int GetBIdx() const override { return IN_B; } - - private: - // Indicates when MlasDynamicQGemmBatch() can be used - bool can_use_dynamic_quant_mlas_{false}; -#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) - // Indicates that the biases are a constant input and thus already quantized / packed - bool dynamic_quant_mlas_bias_data_was_packed_{false}; -#endif }; class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase { @@ -381,7 +264,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { } } // Guard against KleidiAI functions being called in non kleidi builds - // TODO: migrate to a suitable override function call for kleidi dynamic qgemm function calls + // TODO: migrate to a suitable override function call for KleidiAI dynamic qgemm function calls #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) else { MatMulComputeHelper helper; @@ -390,10 +273,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { // deleted during session init post prepacking
nullptr, nullptr)); - + // allocate the kernel’s output tensor from the execution context Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); - // Bail out early if the output is going to be empty + // Bail out early if any dimension is 0, the product (and hence the total number of elements) is 0 if (y->Shape().Size() == 0) return Status::OK(); diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index ca81b9fa426ee..d652989a610e0 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -53,6 +53,7 @@ namespace ArmKleidiAI { // By default we should try for SME2 first before falling back to SME. inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2(); +inline const bool UseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME(); // Buffer packing routines. // diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp index 1d682b372e2f5..4164a0134a980 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -11,10 +11,18 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" #include "mlasi_kleidiai.h" +// Thread-local reusable buffers to reduce allocation overhead across tiles. +struct KaiTlsBuffersQgemm { + std::vector lhs_packed; + std::vector lhs_base_table; +}; +static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm; + //Matmul with float output of dynamic quantized A and symmetric quantized B. size_t @@ -80,42 +88,148 @@ MLASCALL ArmKleidiAI::MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, + const size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) { - for (auto b = BatchN; b > 0; --b,++DataParams) { - auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); + const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); + const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); - //TODO enable multi-threading for lhs packing and matmul - MLAS_UNREFERENCED_PARAMETER(ThreadPool); + size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); + size_t n_step = UseSME2 ? 
kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); - //Dynamic Quantize A - lhs - auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); - std::byte* lhs = nullptr; - std::unique_ptr fallback; + if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 ) { + return; + } + + //We are required to enforce errors when we reach this stage as we will not be able + //to reverse the packing decision that was made for RHS. + + ORT_ENFORCE(DataParams != nullptr, "Dynamic QGEMM requires valid DataParams."); + ORT_ENFORCE(Shape.K > 0, "Dynamic QGEMM requires Shape.K to be non-zero."); + + for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) { + const auto& params = DataParams[batch_idx]; + ORT_ENFORCE(params.A != nullptr, "Dynamic QGEMM requires non-null A pointer for batch ", batch_idx); + ORT_ENFORCE(params.C != nullptr, "Dynamic QGEMM requires non-null C pointer for batch ", batch_idx); + ORT_ENFORCE(params.PackedB != nullptr, "Dynamic QGEMM requires non-null PackedB pointer for batch ", batch_idx); + const size_t lda = params.lda != 0 ? params.lda : Shape.K; + const size_t ldc = params.ldc != 0 ? params.ldc : Shape.N; + ORT_ENFORCE(lda >= Shape.K, "lda (", lda, ") must be >= Shape.K (", Shape.K, ") for batch ", batch_idx); + ORT_ENFORCE(ldc >= Shape.N, "ldc (", ldc, ") must be >= Shape.N (", Shape.N, ") for batch ", batch_idx); + } + + //Dynamic Quantize A - lhs + const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* LhsPackedData = nullptr; - if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { - lhs = static_cast(DataParams->Workspace); + if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) { + + g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize); + } + g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize); + LhsPackedData = g_kai_tls_qgemm.lhs_packed.data(); + + //Per-batch table of lhs + if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) { + + g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize); + } + g_kai_tls_qgemm.lhs_base_table.resize(BatchSize); + // Capture the shared batch table pointer so worker threads use the same backing storage. 
+ const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data(); + // B batches require no packing + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + + std::byte* lhs = nullptr; + if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) { + lhs = static_cast(DataParams[batch_idx].Workspace); } else { - fallback = std::make_unique(lhs_size); - lhs = fallback.get(); + lhs = &(LhsPackedData[LhsPackedStride * batch_idx]); } - KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32" << " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0"); - kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, - Shape.K*sizeof(float), lhs); - - KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa"); - kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( - Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, - DataParams->C, - Shape.N * sizeof(float), - sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs); + tls_lhs_base[batch_idx] = lhs; + }); + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(Shape.M, m_step); // M + dim[2] = MlasDivRoundup(Shape.N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(Shape.M, m_step); + dim[2] = MlasDivRoundup(Shape.N, n_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K) + : kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K); + + const std::byte* B_base = reinterpret_cast(DataParams[BIdx].PackedB); + auto BTile = reinterpret_cast(B_base + rhs_packed_offset); + + // Get lhs tile, A + const size_t lhs_packed_offset = + UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K) + : kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K); + + const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace; + auto ATile = reinterpret_cast(A_base + lhs_packed_offset); + + auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? 
(Shape.M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? (Shape.N - NIdx * n_step) : n_step; + + float* dst_tile = reinterpret_cast<float*>( + reinterpret_cast<std::byte*>(DataParams[BIdx].C) + + MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) ); - } + + if (UseSME2) { + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( + TileSizeM, TileSizeN, Shape.K, ATile, BTile, + dst_tile, + DataParams[BIdx].ldc * sizeof(float), + sizeof(float), + -std::numeric_limits<float>::max(), std::numeric_limits<float>::max() + ); + } + else { + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa( + TileSizeM, TileSizeN, Shape.K, ATile, BTile, + dst_tile, + DataParams[BIdx].ldc * sizeof(float), + sizeof(float), + -std::numeric_limits<float>::max(), std::numeric_limits<float>::max() + ); + } + }); } diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 186dc81d7b7b7..4e523b47e3f44 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -206,7 +206,7 @@ MLASCALL MlasIsDynamicQGemmAvailable() { #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) - return ArmKleidiAI::UseSME2; + return (ArmKleidiAI::UseSME || ArmKleidiAI::UseSME2); #else return false; #endif diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index e26eae19b8fd4..b39d2271575f2 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -1,6 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include +#include +#include + +#include "core/common/cpuid_info.h" #include "core/framework/op_kernel.h" #include "core/mlas/inc/mlas.h" #include "core/providers/common.h" @@ -20,6 +27,11 @@ class MatMulIntegerBase : public OpKernel { // only pack Matrix B if (input_idx == GetBIdx()) { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (TryKleidiaiDynamicPrePack(tensor, input_idx, alloc, is_packed, prepacked_weights)) { + return Status::OK(); + } +#endif // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. b_shape_ = tensor.Shape(); @@ -89,6 +101,248 @@ class MatMulIntegerBase : public OpKernel { return false; } + virtual int GetBScaleIdx() const { + return -1; + } + + virtual int GetBZeroPointIdx() const { + return -1; + } + + virtual int GetBiasIdx() const { + return -1; + } + + virtual bool SupportsKleidiaiDynamicQuant() const { + return false; + } + + bool can_use_dynamic_quant_mlas_{false}; + +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + struct KleidiaiDynamicPackContext { + const Tensor* scale{nullptr}; + const Tensor* bias{nullptr}; + const uint8_t* b_data{nullptr}; + size_t K{0}; + size_t N{0}; + std::optional transposed_buffer; + }; + /* + Helper method to pre-pack Matrix B using Arm® KleidiAI™ packing if eligible. + + Returns false if KleidiAI dynamic quantization is not supported or the index of the input tensor is not input B's index. + If these checks pass, it prepares a dynamic quantization pack context and calls PrepareKleidiaiDynamicPack for the remaining policy checks. + If those policies are also satisfied, it calls the helper that executes the pre-packing through KleidiAI. + Returns true if pre-packing was performed and false otherwise.
+ */ + bool TryKleidiaiDynamicPrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* prepacked_weights) { + if (!SupportsKleidiaiDynamicQuant() || input_idx != GetBIdx()) { + return false; + } + + KleidiaiDynamicPackContext ctx; + if (!PrepareKleidiaiDynamicPack(tensor, alloc, ctx)) { + return false; + } + + return ExecuteKleidiaiDynamicPack(ctx, alloc, is_packed, prepacked_weights); + } + /* + Helper method to determine if Arm® KleidiAI™ dynamic quantization pre-packing policies are satisfied. + + Checks for the presence of a constant B scale tensor, symmetry of the B zero point, and validity of the scales. + Also checks whether the shape of tensor B is supported by KleidiAI and whether the bias tensor is a constant input. + Transposes B if necessary. + Sets the can_use_dynamic_quant_mlas_ flag accordingly and returns true if all policies are satisfied. + */ + bool PrepareKleidiaiDynamicPack(const Tensor& tensor, + AllocatorPtr alloc, + KleidiaiDynamicPackContext& ctx) { + can_use_dynamic_quant_mlas_ = false; + dynamic_quant_mlas_bias_data_was_packed_ = false; + + ctx.scale = GetConstantInputTensor(GetBScaleIdx()); + if (ctx.scale == nullptr) { + return false; + } + + if (!IsZeroPointSymmetric()) { + return false; + } + + if (!AreScalesValid(*ctx.scale)) { + can_use_dynamic_quant_mlas_ = false; + return false; + } + + if (!IsBShapeSupportedForDynamicQuant(tensor.Shape())) { + can_use_dynamic_quant_mlas_ = false; + return false; + } + + ctx.bias = GetConstantInputTensor(GetBiasIdx()); + if (ctx.bias != nullptr) { + dynamic_quant_mlas_bias_data_was_packed_ = true; + } + + ctx.K = static_cast<size_t>(b_shape_[0]); + ctx.N = static_cast<size_t>(b_shape_[1]); + ctx.b_data = static_cast<const uint8_t*>(tensor.DataRaw()); + + if (IsBTransposed()) { + std::swap(ctx.K, ctx.N); + ctx.b_data = quantization::TransPoseInputData(ctx.b_data, ctx.transposed_buffer, alloc, ctx.N, ctx.K); + } + + can_use_dynamic_quant_mlas_ = true; + return true; + } + /* + Helper method to execute Arm® KleidiAI™ dynamic quantization pre-packing. + + If the can_use_dynamic_quant_mlas_ flag was set by the previous policy checks, it queries the packed + RHS matrix size in bytes and allocates the packed buffer; if the size is 0 it returns false. + It then gathers the scale and bias data and calls the packing function. + The pre-packed buffer is cached for sharing in the same way the generic Mlas path does. + */ + bool ExecuteKleidiaiDynamicPack(const KleidiaiDynamicPackContext& ctx, + AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* prepacked_weights) { + if (!can_use_dynamic_quant_mlas_) { + return false; + } + + is_packed = false; + + const size_t packed_b_size = MlasDynamicQgemmPackBSize(ctx.N, ctx.K); + if (packed_b_size == 0) { + can_use_dynamic_quant_mlas_ = false; + return false; + } + + packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size, true); + memset(packed_b_.get(), 0, packed_b_size); + + const auto scales = static_cast<size_t>(ctx.scale->Shape().Size()) == ctx.N + ? std::vector<float>(&ctx.scale->Data<float>()[0], + &ctx.scale->Data<float>()[ctx.N]) + : std::vector<float>(ctx.N, ctx.scale->Data<float>()[0]); + + const auto biases = ctx.bias != nullptr + ?
std::vector<float>(&ctx.bias->Data<float>()[0], + &ctx.bias->Data<float>()[ctx.N]) + : std::vector<float>(ctx.N, 0.f); + + MlasDynamicQgemmPackB(ctx.N, ctx.K, reinterpret_cast<const int8_t*>(ctx.b_data), + scales.data(), biases.data(), packed_b_.get()); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size); + } + + is_packed = true; + return true; + } + /* + Helper for checking the zero point tensor of input B. Arm® KleidiAI™ requires symmetric zero points. + + If the zero point tensor is present as a constant input, this method asserts that its data type is either uint8_t or int8_t + and checks that all zero point values are zero; if any are non-zero, the can_use_dynamic_quant_mlas_ flag is set to false. + If the zero point input is not present, the flag is set to true since symmetric quantization is assumed. + Returns the flag. + */ + bool IsZeroPointSymmetric() { + const Tensor* b_zp_constant_tensor = GetConstantInputTensor(GetBZeroPointIdx()); + if (b_zp_constant_tensor != nullptr) { + assert(b_zp_constant_tensor->IsDataType<uint8_t>() || b_zp_constant_tensor->IsDataType<int8_t>()); + const auto* zp_bytes = static_cast<const std::byte*>(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + can_use_dynamic_quant_mlas_ = std::none_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + return can_use_dynamic_quant_mlas_; + } + + const auto input_defs = Info().node().InputDefs(); + const int b_zp_idx = GetBZeroPointIdx(); + const bool b_zp_input_exists = b_zp_idx >= 0 && + static_cast<size_t>(b_zp_idx) < input_defs.size() && + input_defs[b_zp_idx]->Exists(); + can_use_dynamic_quant_mlas_ = !b_zp_input_exists; + return can_use_dynamic_quant_mlas_; + } + /* + Helper method to check the validity of the scales tensor for Arm® KleidiAI™ dynamic quantization. + The scales are invalid, and the can_use_dynamic_quant_mlas_ flag is cleared, if any scale is non-finite or non-positive. + Otherwise the flag is left as set by the earlier checks. Returns the flag. + */ + bool AreScalesValid(const Tensor& b_scale_tensor) { + const auto bs = b_scale_tensor.DataAsSpan<float>(); + const bool has_invalid = + std::any_of(bs.begin(), bs.end(), + [](float s) { return !std::isfinite(s) || s <= 0.0f; }); + + if (has_invalid) { + can_use_dynamic_quant_mlas_ = false; + } + return can_use_dynamic_quant_mlas_; + } + /* + Helper to promote a 1D tensor to 2D, for Arm® KleidiAI™ dynamic quantization, if necessary. Returns false if the tensor rank is 0. + */ + bool PromoteBShapeIfNeeded() { + if (b_shape_.NumDimensions() == 0) { + return false; // rank-0 tensor is not supported + } + + if (b_shape_.NumDimensions() == 1) { + TensorShapeVector expanded{1, b_shape_[0]}; + b_shape_ = TensorShape(expanded); + } + + return true; + } + /* + Helper method to check whether the shape of tensor B satisfies the Arm® KleidiAI™ dynamic quantization policy. + The shape should be at least 2D and all dimensions except the last two should be 1. A 1D tensor is promoted to 2D. + */ + bool IsBShapeSupportedForDynamicQuant(const TensorShape& tensor_shape) { + b_shape_ = tensor_shape; + if (!PromoteBShapeIfNeeded()) { + return false; + } + + for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) { + if (b_shape_[i] != 1) { + return false; + } + } + b_shape_ = tensor_shape; + return true; + } + /* + Checks the constant-initializer input at the given index and returns the constant tensor if present.
+ Returns nullptr if the index is invalid or the input is not a constant initializer. + */ + const Tensor* GetConstantInputTensor(int input_idx) const { + if (input_idx < 0) { + return nullptr; + } + const OrtValue* ort_value = nullptr; + if (!Info().TryGetConstantInput(input_idx, &ort_value)) { + return nullptr; + } + + return &ort_value->Get<Tensor>(); + } + + bool dynamic_quant_mlas_bias_data_was_packed_{false}; +#endif + // Check if quantization parameter of B is supported. // It should be in one of the formats below: // 1. Scalar diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 30b0c0fcf73c3..9584c90f264de 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -11,6 +11,7 @@ #include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" +#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO #include #include @@ -489,5 +490,163 @@ TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) { test_case({15, 14, 13}, {15, 13, 27}, {15, 1, 27}); } +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + +static bool HasArmSME() { + return (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME() || MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()); } + +// Helper to build a tiny M=2, K=4, N=3 case we reuse. +struct KleidiDynMatMulData { + static constexpr int64_t M = 2; + static constexpr int64_t K = 4; + static constexpr int64_t N = 3; + + std::vector<float> a = { + 1.f, 2.f, 3.f, 4.f, + -1.f, -2.f, -3.f, -4.f}; + std::vector<int8_t> b = { + 1, 0, -1, + 2, -1, 0, + 0, 1, 2, + -2, 0, 1}; + std::vector<float> b_scale = {0.5f, 0.25f, 0.125f}; + std::vector<int8_t> b_zp = {0, 0, 0}; + + std::vector<float> Reference(float bias0, float bias1, float bias2) const { + std::vector<float> out(M * N, 0.f); + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.f; + for (int64_t k = 0; k < K; ++k) { + const float b_val = (static_cast<float>(b[k * N + n]) - b_zp[n]) * b_scale[n]; + sum += a[m * K + k] * b_val; + } + const float bias = (n == 0 ? bias0 : n == 1 ? bias1 + : bias2); + out[m * N + n] = sum + bias; + } + } + return out; + } + std::vector<float> Reference3D(float bias0, float bias1, float bias2, int64_t leading = 1) const { + auto base = Reference(bias0, bias1, bias2); + std::vector<float> out; + out.reserve(leading * M * N); + for (int64_t i = 0; i < leading; ++i) { + out.insert(out.end(), base.begin(), base.end()); + } + return out; + } +}; + +// 1. Bias provided as initializer -> Kleidi packs bias and skips runtime add. +TEST(DynamicQuantizeMatMul, KleidiBiasInitializer) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + const std::vector<float> bias = {0.25f, -0.5f, 1.125f}; + auto expected = data.Reference(bias[0], bias[1], bias[2]); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true /*initializer*/); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true /*initializer*/); + test.AddInput("bias", {data.N}, bias, true /*initializer*/); + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.2f); + test.Run(); +} + +// 2. Bias as runtime tensor -> exercise deferred bias add branch.
+TEST(DynamicQuantizeMatMul, KleidiBiasRuntime) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + const std::vector bias = {1.0f, 0.0f, -0.75f}; + auto expected = data.Reference(bias[0], bias[1], bias[2]); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddInput("bias", {data.N}, bias, false /*runtime*/); + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.2f); + test.Run(); +} + +// 3. Non-zero zero-points -> Kleidi pack rejected, falls back to generic path. +TEST(DynamicQuantizeMatMul, KleidiRejectsNonZeroZeroPoint) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + data.b_zp = {1, 0, 0}; // violates symmetry, Kleidi path disabled + auto expected = data.Reference(0.f, 0.f, 0.f); // still compare to reference + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp); + test.AddOptionalInputEdge(); // no bias + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.2f); + test.Run(); // succeeds, but exercises the “fallback” branch +} + +// 4. Invalid scales -> Kleidi pack rejected. +TEST(DynamicQuantizeMatMul, KleidiRejectsInvalidScale) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + data.b_scale[1] = 0.f; // invalid + auto expected = data.Reference(0.f, 0.f, 0.f); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddOptionalInputEdge(); + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.2f); + test.Run(); +} + +// 5. Unsupported B-shape (e.g., 3D) -> Kleidi pack rejected. 
+TEST(DynamicQuantizeMatMul, KleidiRejectsUnsupportedBShape) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + std::vector b_3d; + b_3d.reserve(2 * data.b.size()); + b_3d.insert(b_3d.end(), data.b.begin(), data.b.end()); + b_3d.insert(b_3d.end(), data.b.begin(), data.b.end()); + std::vector b_shape = {2, data.K, data.N}; + + std::vector b_scale_3d; + b_scale_3d.reserve(2 * data.N); + b_scale_3d.insert(b_scale_3d.end(), data.b_scale.begin(), data.b_scale.end()); + b_scale_3d.insert(b_scale_3d.end(), data.b_scale.begin(), data.b_scale.end()); + + std::vector b_zp_3d; + b_zp_3d.reserve(2 * data.N); + b_zp_3d.insert(b_zp_3d.end(), data.b_zp.begin(), data.b_zp.end()); + b_zp_3d.insert(b_zp_3d.end(), data.b_zp.begin(), data.b_zp.end()); + + auto expected = data.Reference3D(0.f, 0.f, 0.f, /*leading=*/2); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", b_shape, b_3d, true); + test.AddInput("b_scale", {2, 1, data.N}, b_scale_3d, true); + test.AddInput("b_zero_point", {2, 1, data.N}, b_zp_3d, true); + + test.AddOptionalInputEdge(); + test.AddOutput("Y", {2, data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.2f); + test.Run(); +} + +#endif // USE_KLEIDIAI + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp index 83f5b7f106d3e..ac946cd7d0828 100644 --- a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -6,8 +6,12 @@ #include "mlas.h" #include "test_util.h" +#include "core/mlas/inc/mlas.h" -class MlasDynamicQgemmTest { +#include +#include + +class MlasDynamicQgemmTestBase { private: MatrixGuardBuffer buffer_a; MatrixGuardBuffer buffer_bf; @@ -15,10 +19,13 @@ class MlasDynamicQgemmTest { MatrixGuardBuffer buffer_c; MatrixGuardBuffer buffer_c_ref; - public: - void Test(size_t M, size_t N, size_t K, size_t BatchSize) { - // Setup buffers for holding various data + protected: + void Run(size_t M, size_t N, size_t K, size_t BatchSize, + MLAS_THREADPOOL* threadpool, bool require_threadpool, const char* run_tag) { + if (require_threadpool && threadpool == nullptr) + GTEST_SKIP() << "Dynamic QGEMM threading path requested but no MLAS thread pool is available."; + // Setup buffers for holding various data float* A = buffer_a.GetBuffer(M * K * BatchSize); // Buffer for holding floating point version of weight matrix float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); @@ -36,6 +43,9 @@ class MlasDynamicQgemmTest { // Quantize Bf → Bq and compute per-column scale and bias per batch std::vector> b_scale_batches(BatchSize, std::vector(N)); std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + std::vector> a_quant_batches(BatchSize, std::vector(M * K)); + std::vector> a_scale_batches(BatchSize, std::vector(M)); + std::vector> a_zero_point_batches(BatchSize, std::vector(M)); for (size_t b = 0; b < BatchSize; ++b) { for (size_t n = 0; n < N; ++n) { @@ -58,6 +68,42 @@ class MlasDynamicQgemmTest { } } + // Quantize A rows to match the dynamic quantization performed by the kernel. 
+ for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + for (size_t k = 0; k < K; ++k) { + float v = A[b * M * K + m * K + k]; + min_val = std::min(min_val, v); + max_val = std::max(max_val, v); + } + float rmin = std::min(0.0f, min_val); + float rmax = std::max(0.0f, max_val); + float inv_scale = (rmax == rmin) ? 1.0f : 255.0f / (rmax - rmin); + float scale = inv_scale ? 1.0f / inv_scale : 0.0f; + float descaled_min = rmin * inv_scale; + float descaled_max = rmax * inv_scale; + float zero_point_from_min_error = -128.0f + descaled_min; + float zero_point_from_max_error = 127.0f + descaled_max; + float zero_point = (zero_point_from_min_error + zero_point_from_max_error > 0.0f) + ? (-128.0f - descaled_min) + : (127.0f - descaled_max); + zero_point = std::clamp(zero_point, -128.0f, 127.0f); + int32_t zp = static_cast(std::nearbyint(zero_point)); + + a_scale_batches[b][m] = scale; + a_zero_point_batches[b][m] = zp; + + for (size_t k = 0; k < K; ++k) { + float v = A[b * M * K + m * K + k]; + int32_t q = static_cast(std::round(v * inv_scale)) + zp; + q = std::clamp(q, -128, 127); + a_quant_batches[b][m * K + k] = static_cast(q); + } + } + } + // Prepare kernel parameters MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); @@ -78,16 +124,16 @@ class MlasDynamicQgemmTest { params[b].PackedB = packed_b; } - // call MlasDynamicQGemmBatch Function - MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); - // Compute reference result for (size_t b = 0; b < BatchSize; ++b) { for (size_t m = 0; m < M; ++m) { for (size_t n = 0; n < N; ++n) { float sum = 0.0f; + const float a_scale = a_scale_batches[b][m]; + const int32_t a_zero_point = a_zero_point_batches[b][m]; for (size_t k = 0; k < K; ++k) { - float a = A[b * M * K + m * K + k]; + int32_t a_q = static_cast(a_quant_batches[b][m * K + k]); + float a = static_cast(a_q - a_zero_point) * a_scale; float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; sum += a * bval; } @@ -96,45 +142,73 @@ class MlasDynamicQgemmTest { } } + std::fill(C, C + M * N * BatchSize, 0.0f); + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, threadpool); + // Validate results - for (size_t i = 0; i < M * N * BatchSize; ++i) { - float abs_c_ref = std::abs(CRef[i]); - float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f; - float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); - float abs_tol = 3.0f; - float allowed = std::max(rel_tol, abs_tol); - float diff = std::abs(C[i] - CRef[i]); - ASSERT_LE(diff, allowed); - } + auto validate = [&](const char* tag) { + SCOPED_TRACE(tag); + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_c_ref = std::abs(CRef[i]); + float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f; + float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); + float abs_tol = 3.0f; + float allowed = std::max(rel_tol, abs_tol); + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, allowed); + } + }; + + validate(run_tag); + } +}; + +class MlasDynamicQgemmSingleThreadTest : public MlasDynamicQgemmTestBase { + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + if (!MlasIsDynamicQGemmAvailable()) + GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME or SME2 but it was not detected. 
Skipping test."; + Run(M, N, K, BatchSize, /*threadpool*/ nullptr, /*require_threadpool*/ false, "SingleThread"); } + static const char* GetTestSuiteName() { return "DynamicQgemmSingleThread"; } +}; - static const char* GetTestSuiteName() { - return "DynamicQgemm"; +class MlasDynamicQgemmThreadPoolTest : public MlasDynamicQgemmTestBase { + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + if (!MlasIsDynamicQGemmAvailable()) + GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME or SME2 but it was not detected. Skipping test."; + MLAS_THREADPOOL* tp = GetMlasThreadPool(); + if (!tp) GTEST_SKIP() << "Mlas thread pool not available"; + Run(M, N, K, BatchSize, tp, /*require_threadpool*/ true, "ThreadPool"); } + static const char* GetTestSuiteName() { return "DynamicQgemmThreaded"; } }; -class DynamicQgemmExecuteTest : public MlasTestFixture { +template +class DynamicQgemmExecuteTest : public MlasTestFixture { public: DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} void TestBody() override { - this->mlas_tester->Test(M_, N_, K_, BatchSize_); + MlasTestFixture::mlas_tester->Test(M_, N_, K_, BatchSize_); } static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { std::stringstream ss; ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; - std::string test_name = ss.str(); testing::RegisterTest( - MlasDynamicQgemmTest::GetTestSuiteName(), + TMlasTester::GetTestSuiteName(), test_name.c_str(), nullptr, test_name.c_str(), __FILE__, __LINE__, - [=]() -> MlasTestFixture* { + [=]() -> MlasTestFixture* { return new DynamicQgemmExecuteTest(M, N, K, BatchSize); }); @@ -158,11 +232,10 @@ class DynamicQgemmExecuteTest : public MlasTestFixture { size_t M_, N_, K_, BatchSize_; }; -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - // Only register tests if MlasDynamicQGemmBatch() has an implementation available. - if (!MlasIsDynamicQGemmAvailable()) { - return size_t{0}; - } +static UNUSED_VARIABLE bool added_single = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); - return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +static UNUSED_VARIABLE bool added_threaded = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); });