From f3594ed138d57eb9bb54fe7a248840103589b652 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 19 Mar 2026 17:36:15 +0000 Subject: [PATCH 1/7] Extend DQ->MatMulNBits fusion to support Gemm + per-tensor/per-channel quantization --- .../onnxruntime_session_options_config_keys.h | 7 + .../core/optimizer/graph_transformer_utils.cc | 14 +- .../selectors_actions/qdq_actions.cc | 220 +++++- .../selectors_actions/qdq_actions.h | 4 +- .../qdq_selector_action_transformer.cc | 23 +- .../qdq_selector_action_transformer.h | 3 +- .../selectors_actions/qdq_selectors.cc | 206 +++++- .../selectors_actions/qdq_selectors.h | 6 +- .../qdq_matmulnbits_transformer_test.cc | 645 +++++++++++++++++- 9 files changed, 1067 insertions(+), 61 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index f0a99bc11c8b3..a9d9ac8323b16 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -391,6 +391,13 @@ static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_k // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; +// Block size used when converting per-tensor or per-axis DQ + MatMul to MatMulNBits. +// Only applies to DQ nodes without an existing block_size attribute (i.e., per-tensor or per-axis quantization). +// Positive value: explicit block_size (must be power-of-2 and >= 16, e.g., 16, 32, 64, 128). +// "0" or not provided: use default block_size of 32. +// "-1": heuristic - largest power-of-2 <= min(K, 256) that minimizes padding. +static const char* const kOrtSessionOptionsQDQMatMulNBitsBlockSize = "session.qdq_matmulnbits_block_size"; + // Enable the DQ->MatMulNBits fusion graph transformer. // "0": disabled (default). "1": enabled. // This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered. 
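For illustration only, a minimal sketch of how a client could opt into the new block-size behavior through the public C++ API (the key string is the one defined above; "model.onnx" is a placeholder):

  #include <onnxruntime_cxx_api.h>

  Ort::Env env;
  Ort::SessionOptions so;
  // Powers of two in [16, 256] are accepted; "0" keeps the default of 32,
  // and "-1" picks the padding-minimizing block size per weight.
  so.AddConfigEntry("session.qdq_matmulnbits_block_size", "-1");
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);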
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 640848d47fe93..3b2677816d136 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -347,6 +347,10 @@ InlinedVector> GenerateTransformers( ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "4")); + const int64_t qdq_matmulnbits_block_size = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize, + "0")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -363,7 +367,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + qdq_matmulnbits_block_size)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -504,6 +509,10 @@ InlinedVector> GenerateTransformersForMinimalB ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "4")); + const int64_t qdq_matmulnbits_block_size = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize, + "0")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; @@ -511,7 +520,8 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + qdq_matmulnbits_block_size)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index fdc0818e8437b..05d7dde81b2e6 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -42,6 +42,67 @@ bool IsDQWeightSigned(int32_t dt_weight) { dt_weight == TensorProto::INT8; } +// Compute the effective block_size for per-tensor/per-channel DQ nodes that lack a block_size attribute. +// session_block_size: 0 = default (32), positive = explicit, -1 = min-padding heuristic. +int64_t ComputeEffectiveBlockSize(int64_t session_block_size, int64_t K) { + // MatMulNBits CPU kernel currently only supports block_size in [16, 256] correctly. + constexpr int64_t kMinBlockSize = 16; + constexpr int64_t kMaxBlockSize = 256; + + if (session_block_size > 0) { + // Explicit block_size — must be power-of-2 and within [kMinBlockSize, kMaxBlockSize]. 
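+    // A power of two has exactly one bit set, so (x & (x - 1)) == 0; e.g. 64 & 63 == 0, while 48 & 47 == 32 rejects 48.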
+ ORT_ENFORCE(session_block_size >= kMinBlockSize && + ((session_block_size & (session_block_size - 1)) == 0), + "Explicit qdq_matmulnbits_block_size must be a power-of-2 and >= ", + kMinBlockSize, ", got: ", session_block_size); + ORT_ENFORCE(session_block_size <= kMaxBlockSize, + "Explicit qdq_matmulnbits_block_size must be <= ", + kMaxBlockSize, ", got: ", session_block_size); + return session_block_size; + } + + if (session_block_size == -1) { + // Heuristic: largest power-of-2 <= min(K, kMaxBlockSize) that minimizes padding. + // Capped at kMaxBlockSize because CPU EP only supports block_size up to kMaxBlockSize correctly. + // We want ceil(K / B) * B - K to be minimized (least wasted padding). + int64_t best_bs = kMinBlockSize; + int64_t best_padding = (((K + (kMinBlockSize - 1)) / kMinBlockSize) * kMinBlockSize) - K; + for (int64_t bs = kMinBlockSize * 2; bs <= std::min(K, kMaxBlockSize); bs *= 2) { + int64_t padding = (((K + bs - 1) / bs) * bs) - K; + if (padding <= best_padding) { + best_padding = padding; + best_bs = bs; + } + } + return best_bs; + } + + // Default (session_block_size == 0): use 32 + return 32; +} + +// Get the DQ block_size: from the attribute if blockwise, or computed for per-tensor/per-channel. +int64_t GetEffectiveBlockSize(const Node& dq_node, int64_t block_size_for_non_blockwise) { + const auto& dq_attrs = dq_node.GetAttributes(); + const auto bs_iter = dq_attrs.find("block_size"); + if (bs_iter != dq_attrs.end() && bs_iter->second.i() > 0) { + return bs_iter->second.i(); + } + + // Derive K from the weight input shape if available. Shape information may be missing even + // when the weight is a constant initializer, so guard against nullptrs / unknown dims. + int64_t K = 32; // reasonable default consistent with ComputeEffectiveBlockSize default + const auto* weight_arg = dq_node.InputDefs()[0]; + if (weight_arg != nullptr) { + const auto* shape = weight_arg->Shape(); + if (shape != nullptr && shape->dim_size() > 0 && shape->dim(0).has_dim_value()) { + K = static_cast(shape->dim(0).dim_value()); + } + } + + return ComputeEffectiveBlockSize(block_size_for_non_blockwise, K); +} + // Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits. // Used by DQMatMulToMatMulNBitsAction. struct TransposedQuantizedTensors { @@ -56,16 +117,17 @@ struct TransposedQuantizedTensors { // Transpose DQ weight/scale/zp tensors from column-wise layout to MatMulNBits layout via MLAS. // default_zp_name_prefix: prefix for auto-generated zero-point name when unsigned type has no explicit zp. +// effective_block_size: the block_size to use for MatMulNBits (may differ from DQ's block_size for per-tensor/per-channel). Status TransposeDQWeightsForMatMulNBits( Graph& graph, const Node& dq_node, const std::string& default_zp_name_prefix, concurrency::ThreadPool* intra_op_thread_pool, + int64_t effective_block_size, TransposedQuantizedTensors& result) { const auto* weight_arg = dq_node.InputDefs()[0]; const auto* scale_arg = dq_node.InputDefs()[1]; const auto* zp_arg = dq_node.InputDefs().size() > 2 ? 
dq_node.InputDefs()[2] : nullptr; - const auto& attrs = dq_node.GetAttributes(); const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), @@ -80,7 +142,7 @@ Status TransposeDQWeightsForMatMulNBits( auto K = weight_arg->Shape()->dim(0).dim_value(); auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = attrs.at("block_size").i(); + auto block_size = effective_block_size; int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); auto bits = DQWeightBits(dt_weight); auto quant_num = (K + block_size - 1) / block_size; @@ -94,8 +156,100 @@ Status TransposeDQWeightsForMatMulNBits( std::optional zp_src; auto cpu_allocator = CPUAllocator::DefaultInstance(); + // Determine if scale/zp need expansion from per-tensor/per-channel to blockwise [quant_num, N]. + const bool is_blockwise = (scale_tensor_proto->dims_size() == 2); + std::optional expanded_scale; + std::optional expanded_zp; + + if (!is_blockwise) { + // Expand scale to [quant_num, N] + expanded_scale.emplace(scale_type, TensorShape{quant_num, N}, cpu_allocator); + bool is_per_tensor = (scale_tensor_proto->dims_size() == 0); + + auto expand_scale = [&](auto* src_data, auto* dst_data) { + if (is_per_tensor) { + auto val = src_data[0]; + for (int64_t i = 0; i < quant_num * N; ++i) { + dst_data[i] = val; + } + } else { + // Per-channel: scale shape [N], replicate across quant_num blocks + for (int64_t b = 0; b < quant_num; ++b) { + for (int64_t n = 0; n < N; ++n) { + dst_data[b * N + n] = src_data[n]; + } + } + } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + expand_scale(scale_src.data(), expanded_scale->MutableData()); + } else { + expand_scale(scale_src.data(), expanded_scale->MutableData()); + } + + // Expand zp if present + if (zp_tensor_proto) { + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); + // Allocate as uint8 with enough bytes to hold quant_num*N packed sub-byte elements. + int64_t expanded_zp_bytes = (quant_num * N * bits + 7) / 8; + expanded_zp.emplace(uint8_type, TensorShape{expanded_zp_bytes}, cpu_allocator); + + // For sub-byte types, the zp is packed in bytes. We need to expand element-wise. + // For 8-bit, each byte is one element. For 4-bit, 2 elements per byte. For 2-bit, 4 elements per byte. + const uint8_t* zp_bytes = zp_src->DataAsByteSpan().data(); + uint8_t* dst_zp_bytes = expanded_zp->MutableData(); + + auto get_element = [bits](const uint8_t* data, int64_t idx) -> uint8_t { + if (bits == 8) return data[idx]; + if (bits == 4) { + uint8_t byte = data[idx / 2]; + return (idx % 2 == 0) ? 
(byte & 0x0F) : (byte >> 4); + } + // bits == 2 + uint8_t byte = data[idx / 4]; + int shift = static_cast((idx % 4) * 2); + return (byte >> shift) & 0x03; + }; + + auto set_element = [bits](uint8_t* data, int64_t idx, uint8_t val) { + if (bits == 8) { + data[idx] = val; + return; + } + if (bits == 4) { + int64_t byte_idx = idx / 2; + if (idx % 2 == 0) { + data[byte_idx] = (data[byte_idx] & 0xF0) | (val & 0x0F); + } else { + data[byte_idx] = (data[byte_idx] & 0x0F) | ((val & 0x0F) << 4); + } + return; + } + // bits == 2 + int64_t byte_idx = idx / 4; + int shift = static_cast((idx % 4) * 2); + data[byte_idx] = (data[byte_idx] & ~(0x03 << shift)) | ((val & 0x03) << shift); + }; + + // Initialize expanded zp to 0 + memset(dst_zp_bytes, 0, expanded_zp->SizeInBytes()); + + for (int64_t b = 0; b < quant_num; ++b) { + for (int64_t n = 0; n < N; ++n) { + int64_t src_idx = is_per_tensor ? 0 : n; + uint8_t val = get_element(zp_bytes, src_idx); + set_element(dst_zp_bytes, b * N + n, val); + } + } + } + } + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); result.weight = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + // Zero-initialize: MLAS 4-bit transpose does not zero-pad when K < block_size, + // leaving uninitialized bytes in the last block's padding region. + memset(result.weight.MutableDataRaw(), 0, result.weight.SizeInBytes()); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); auto scale_size = (TensorShape{N, quant_num}).Size(); @@ -104,7 +258,13 @@ Status TransposeDQWeightsForMatMulNBits( std::string zp_dst_name; auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); - if (zp_tensor_proto) { + if (!is_blockwise && expanded_zp.has_value()) { + // Per-tensor/per-channel path with expanded zero-point + zp_dst_name = graph.GenerateNodeArgName( + (zp_arg ? zp_arg->Name() : default_zp_name_prefix + "_zero_point") + "_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + } else if (zp_tensor_proto) { + // Blockwise path with explicit zero-point zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); @@ -116,10 +276,15 @@ Status TransposeDQWeightsForMatMulNBits( // Dispatch MLAS transpose based on scale type, bits, and signedness. auto transpose = [&](auto* scale_data, auto* scale_dst_data) { - using ScaleType = std::remove_pointer_t; + using ScaleType = std::remove_const_t>; bool is_signed = IsDQWeightSigned(dt_weight); const uint8_t* src_w = weight_src.DataAsByteSpan().data(); - const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + const uint8_t* src_zp = nullptr; + if (expanded_zp.has_value()) { + src_zp = expanded_zp->Data(); + } else if (zp_src.has_value()) { + src_zp = zp_src->DataAsByteSpan().data(); + } uint8_t* dst_w = result.weight.MutableData(); uint8_t* dst_zp = result.zero_point ? result.zero_point->MutableData() : nullptr; int K_int = static_cast(K); @@ -148,9 +313,11 @@ Status TransposeDQWeightsForMatMulNBits( }; if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - transpose(scale_src.data(), result.scale.MutableData()); + const float* s_data = expanded_scale.has_value() ? 
expanded_scale->Data() : scale_src.data(); + transpose(s_data, result.scale.MutableData()); } else { - transpose(scale_src.data(), result.scale.MutableData()); + const MLFloat16* s_data = expanded_scale.has_value() ? expanded_scale->Data() : scale_src.data(); + transpose(s_data, result.scale.MutableData()); } result.weight_proto = utils::TensorToTensorProto(result.weight, weight_dst_name, true); @@ -430,7 +597,8 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) + concurrency::ThreadPool* intra_op_thread_pool, + int64_t block_size_for_non_blockwise) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -440,7 +608,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + block_size_for_non_blockwise_{block_size_for_non_blockwise} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -449,7 +618,6 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) NodeAttributes extra_attributes; const auto* dq_node = runtime_state.selected_nodes.Input(0); - auto& attrs = dq_node->GetAttributes(); const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); @@ -457,7 +625,8 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); int32_t dt_weight = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); utils::SetNodeAttribute(utils::MakeAttribute("bits", DQWeightBits(dt_weight)), extra_attributes); - utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); + int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", effective_bs), extra_attributes); return extra_attributes; } @@ -467,9 +636,11 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, Node& replacement_node) const { const auto* dq_node = selected_nodes.Input(0); + int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_); + TransposedQuantizedTensors transposed; ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( - graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, transposed)); + graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, effective_bs, transposed)); auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); @@ -483,6 +654,31 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, replacement_node.MutableInputArgsCount().push_back(1); } + // If the target was Gemm, strip Gemm-specific attributes from the replacement MatMulNBits node + // and wire the bias (if present) to MatMulNBits input 5. 
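+  // Gemm computes alpha * (A * B) + beta * C; the selector only admits alpha == 1, transA == transB == 0
+  // (and beta == 1 when C is present), so dropping these attributes preserves the node's semantics.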
+ const auto& target = selected_nodes.Target(); + if (target.OpType() == "Gemm") { + replacement_node.ClearAttribute("alpha"); + replacement_node.ClearAttribute("beta"); + replacement_node.ClearAttribute("transA"); + replacement_node.ClearAttribute("transB"); + + // Wire Gemm bias to MatMulNBits input 5 (bias slot). + // The bias can be a direct float tensor or the output of a DQ node. + const auto& target_inputs = target.InputDefs(); + if (target_inputs.size() > 2 && target_inputs[2] && target_inputs[2]->Exists()) { + // MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zp(opt), 4:g_idx(opt), 5:bias(opt) + // Pad with empty NodeArgs up to position 5. + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (input_defs.size() < 5) { + input_defs.push_back(&empty_arg); + replacement_node.MutableInputArgsCount().push_back(1); + } + input_defs.push_back(const_cast(target_inputs[2])); + replacement_node.MutableInputArgsCount().push_back(1); + } + } + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 02a8353707599..f0b1e17a7ffe0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -86,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + int64_t block_size_for_non_blockwise = 0); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -105,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + const int64_t block_size_for_non_blockwise_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 8cab6911646f2..c88ae9b8c4782 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -296,15 +296,19 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is 2/4/8-bit int (int2/uint2, int4/uint4, int8/uint8). DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. + // Also supports per-tensor and per-channel (axis=1) quantized DQ weights by expanding + // scales/zero-points to blockwise format using qdq_matmulnbits_block_size. 
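+  // When the fused target is a Gemm, any constant or DQ-produced bias is carried over to MatMulNBits
+  // input 5 by DQMatMulToMatMulNBitsAction::ProcessNewNode; a bias DQ itself is kept in the graph.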
const std::string action_name{"DQMatMulToMatMulNBits"}; std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + qdq_matmulnbits_block_size); #if !defined(ORT_MINIMAL_BUILD) // Include "" (empty string) to match nodes not yet assigned to an EP. @@ -315,7 +319,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""}; std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, - {{"MatMul", {}}}, + {{"MatMul", {}}, + {"Gemm", {}}}, std::move(selector), std::move(action)); @@ -370,7 +375,8 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { SelectorActionRegistry CreateSelectorActionRegistry( bool is_int8_allowed, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -384,7 +390,8 @@ SelectorActionRegistry CreateSelectorActionRegistry( WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + qdq_matmulnbits_block_size); return qdq_selector_action_registry; } @@ -395,11 +402,13 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) + concurrency::ThreadPool* intra_op_thread_pool, + int64_t qdq_matmulnbits_block_size) : SelectorActionTransformer{ "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool), + intra_op_thread_pool, + qdq_matmulnbits_block_size), apply_context, // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. 
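As a worked example of the "-1" heuristic in ComputeEffectiveBlockSize (qdq_actions.cc above): for K=37 the candidates are 16 (padding 48-37=11) and 32 (padding 64-37=27), so 16 wins; for K=96 both 16 and 32 give zero padding and the <= comparison keeps the larger, so 32 wins; for K=768 every power of two up to the 256 cap divides K, so 256 wins.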
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index dce1cd44fd3ea..8294c839cfe42 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -29,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + int64_t qdq_matmulnbits_block_size = 0); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 8a00fe11ff3fd..ef9e1b0cad490 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -6,6 +6,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/graph/graph.h" +#include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -558,11 +559,14 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } } -// Validate that a DQ node has the correct structure for MatMulNBits fusion: -// - weight type is 2/4/8-bit int, scale type is float or float16 -// - blockwise quantization along axis 0, block_size is power-of-2 and >= 16 -// - weight/scale/zp are constant initializers with rank 2 and consistent shapes -static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq_node) { +// Validate that a DQ node has the correct structure for MatMulNBits fusion. +// Supports three quantization granularities: +// - Blockwise: axis=0, block_size >= 16 and power-of-2, scale/zp rank 2 +// - Per-tensor: scale is scalar (rank 0), no block_size attribute +// - Per-channel (axis=1): scale is 1D with shape [N], weight is 2D [K,N], no block_size attribute +// In all cases: weight type is 2/4/8-bit int, scale type is float or float16, +// weight/scale/zp are constant initializers. +static bool ValidateDQForMatMulNBits(const Graph& graph, const Node& dq_node) { const auto* weight_arg = dq_node.InputDefs()[0]; const auto* scale_arg = dq_node.InputDefs()[1]; const auto* zero_point_arg = dq_node.InputDefs().size() == 3 ? 
dq_node.InputDefs()[2] : nullptr; @@ -578,22 +582,6 @@ static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq return false; } - // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 - const auto& dq_attrs = dq_node.GetAttributes(); - if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { - return false; - } - - const auto a_iter = dq_attrs.find("block_size"); - if (a_iter == dq_attrs.end()) { - return false; - } - - auto block_size = a_iter->second.i(); - if (block_size < 16 || ((block_size - 1) & block_size)) { - return false; - } - // weight, scale and zero points (if exists) must be constants const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); @@ -607,18 +595,124 @@ static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq return false; } - // weight, scale and zero points (if exists) must have the rank 2 - if (weight_tensor_proto->dims_size() != 2 || scale_tensor_proto->dims_size() != 2 || - (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + // weight must be rank 2 + if (weight_tensor_proto->dims_size() != 2) { return false; } - // check weight, scale and zero points (if exists) shapes - if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || - weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || - (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || - zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + const auto& dq_attrs = dq_node.GetAttributes(); + const auto block_size_iter = dq_attrs.find("block_size"); + const bool has_block_size = block_size_iter != dq_attrs.end() && block_size_iter->second.i() > 0; + + if (has_block_size) { + // --- Blockwise path (existing logic) --- + if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return false; + } + + auto block_size = block_size_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return false; + } + + if (scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return false; + } + + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return false; + } + } else { + // --- Per-tensor or per-channel path --- + int scale_rank = scale_tensor_proto->dims_size(); + auto N = weight_tensor_proto->dims()[1]; + + if (scale_rank == 0) { + // Per-tensor: scalar scale, optional scalar zp + if (zp_tensor_proto && zp_tensor_proto->dims_size() != 0) { + return false; + } + } else if (scale_rank == 1 && scale_tensor_proto->dims()[0] == N) { + // Per-channel (axis=1): scale shape [N], axis must be 1 + const auto a_iter = dq_attrs.find("axis"); + // DQ default axis is 1, so absent axis is OK + if (a_iter != dq_attrs.end() && a_iter->second.i() != 1) { + return false; + } + if (zp_tensor_proto && (zp_tensor_proto->dims_size() != 1 || zp_tensor_proto->dims()[0] != N)) { + return false; + } + } else { + // Unsupported quantization granularity + return false; + } + } + + return true; +} + +// Validate 
Gemm attributes for DQ->MatMulNBits fusion. +// Gemm must be equivalent to MatMul: alpha=1, transA=0, transB=0. +// If bias exists, beta must be 1 and bias shape must be [N]. +static bool ValidateGemmForDQMatMulNBits(const Graph& graph, const Node& gemm_node, const Node& weight_dq_node) { + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm_node, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) + return false; + if (const auto* trans_a = graph_utils::GetNodeAttribute(gemm_node, "transA"); + trans_a && trans_a->i() != 0) + return false; + if (const auto* trans_b = graph_utils::GetNodeAttribute(gemm_node, "transB"); + trans_b && trans_b->i() != 0) return false; + + const auto& inputs = gemm_node.InputDefs(); + if (inputs.size() > 2 && inputs[2] && inputs[2]->Exists()) { + // Bias exists — beta must be 1.0 + if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm_node, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) + return false; + + // Bias shape must be [N] where N = weight dim 1. Prefer reading N and + // bias length from constant initializers when available, and fall back to + // NodeArg::Shape(). + const auto* weight_arg = weight_dq_node.InputDefs()[0]; + const auto* weight_initializer = graph.GetConstantInitializer(weight_arg->Name(), true); + int64_t N = -1; + + if (weight_initializer) { + if (weight_initializer->dims_size() != 2) { + return false; + } + N = weight_initializer->dims(1); + } else { + const auto* weight_shape = weight_arg->Shape(); + if (!weight_shape || weight_shape->dim_size() != 2 || + !utils::HasDimValue(weight_shape->dim(1))) { + return false; + } + N = weight_shape->dim(1).dim_value(); + } + + const auto* bias_arg = inputs[2]; + const auto* bias_initializer = graph.GetConstantInitializer(bias_arg->Name(), true); + + if (bias_initializer) { + if (bias_initializer->dims_size() != 1 || + bias_initializer->dims(0) != N) { + return false; + } + } else { + const auto* bias_shape = bias_arg->Shape(); + if (!bias_shape || bias_shape->dim_size() != 1 || + !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != N) { + return false; + } + } } return true; @@ -637,18 +731,55 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod } const auto& graph = graph_viewer.GetGraph(); + const bool is_gemm = node.OpType() == "Gemm"; + + if (is_gemm) { + // Gemm: accept 1 DQ (weight only) or 2 DQs (weight + bias). 
+ if (dq_nodes.size() < 1 || dq_nodes.size() > 2) { + return false; + } + } else { + // MatMul: exactly 1 DQ input + if (dq_nodes.size() != 1) { + return false; + } + } - // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output - if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + // Find the weight DQ node — the one feeding input 1 (B) + const Node* weight_dq = nullptr; + for (const auto* dq : dq_nodes) { + if (node.InputDefs()[1] == dq->OutputDefs()[0]) { + weight_dq = dq; + break; + } + } + + if (!weight_dq) { return false; } - // DQ must be MatMul's the second input - if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + // Weight DQ must have exactly 1 output edge and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *weight_dq, 1)) { return false; } - return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]); + if (is_gemm) { + // If there's a second DQ node (for bias), it must feed input 2 + if (dq_nodes.size() == 2) { + const Node* bias_dq = (dq_nodes[0] == weight_dq) ? dq_nodes[1] : dq_nodes[0]; + if (node.InputDefs().size() <= 2 || !node.InputDefs()[2] || + node.InputDefs()[2] != bias_dq->OutputDefs()[0]) { + return false; + } + } + + // Validate Gemm attributes (alpha=1, transA=0, transB=0, beta=1 if bias) + if (!ValidateGemmForDQMatMulNBits(graph, node, *weight_dq)) { + return false; + } + } + + return ValidateDQForMatMulNBits(graph, *weight_dq); } bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, @@ -701,6 +832,13 @@ void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex); } +void DQMatMulToMatMulNBitsSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { + // Keep only the weight DQ (first entry). If a Gemm has a bias DQ, it will be in + // position 1 — trim it so RemoveNodes does not delete it. The bias DQ's output + // is wired to MatMulNBits input 5 in ProcessNewNode. + builder.input_nodes.resize(1); +} + bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 79c374b301442..10d307b4a003c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -454,11 +454,15 @@ class MatMulSelector : public BaseSelector { compatible_providers) {} }; -// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +// Convert "1 DQ node for input B -> MatMul/Gemm" to "MatMulNBits" class DQMatMulToMatMulNBitsSelector : public BaseSelector { public: explicit DQMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) : BaseSelector(std::make_unique(), compatible_providers) {} + + // Only keep the weight DQ in the selection. Any bias DQ (for Gemm) is excluded + // so that RemoveNodes does not remove it — its output is wired through to MatMulNBits. 
+ void UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const override; }; // Input: DQ nodes for A, B and optional C diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 5d7eda39be271..b2344858dd1a1 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -179,7 +179,7 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { @@ -295,7 +295,7 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); if constexpr (use_zp) { @@ -727,7 +727,7 @@ RunDQMatMulFP16Converted(const std::vector& input1_shape, utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); - auto scale_shape = std::vector{weight_shape}; + std::vector scale_shape = weight_shape; scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); @@ -780,6 +780,645 @@ TEST(QDQTransformerTests, DQMatMulFP16ConvertedToMatMulNBits) { RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); } +// Per-tensor DQ -> MatMul conversion to MatMulNBits +// DQ has scalar scale (and optional scalar zero-point), no block_size attribute. 
+// Input1 +// | DQ(per-tensor) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulPerTensorConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + // Scalar scale (per-tensor) + auto* scale_arg = builder.MakeInitializer({}, {10.0f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(std::vector{}, T(1, 0), T(1, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 0.01 /*per_sample_tolerance - higher due to blockwise accumulation reordering*/, + 5e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorConvertedToMatMulNBits) { + // Per-tensor int4/uint4 with and without zero-point + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); + // With accuracy_level=1 + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 1); + RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 1); + // K not divisible by default block_size (32) + RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); + RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); + // N=1 (edge case: single column) + RunDQMatMulPerTensorConverted({12, 768}, {768, 1}, 0); + RunDQMatMulPerTensorConverted({12, 768}, {768, 1}, 0); +} + +// Per-channel (axis=1) DQ -> MatMul conversion to MatMulNBits +// DQ has 1D scale shape [N], axis=1, no block_size attribute. 
+// Input1 +// | DQ(per-channel axis=1) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulPerChannelConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + int64_t N = weight_shape[1]; + // 1D scale shape [N] for per-channel (axis=1) + auto* scale_arg = builder.MakeInitializer({N}, 8.0f, 12.0f); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(1)), attrs); + + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(std::vector{N}, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerChannelConvertedToMatMulNBits) { + // Per-channel int4/uint4 with and without zero-point + RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); + RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); + // With accuracy_level=1 + RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 1); + // K not divisible by default block_size (32) + RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); + RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); +} + +// Negative test: per-axis axis=0 with 1D scale should NOT fuse +template +void RunDQMatMulPerAxisAxis0NotConverted(const std::vector& input1_shape, + const std::vector& weight_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + int64_t K = weight_shape[0]; + // 1D scale shape [K] for per-axis axis=0 — should NOT match + auto* scale_arg = builder.MakeInitializer({K}, 8.0f, 12.0f); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), attrs); + + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, 
{dq_output}, "", &attrs); + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn = [](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "0"); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerAxisAxis0NotConvertedToMatMulNBits) { + RunDQMatMulPerAxisAxis0NotConverted({12, 32}, {32, 16}); + RunDQMatMulPerAxisAxis0NotConverted({12, 32}, {32, 16}); +} + +// Per-tensor DQ -> MatMul with configurable block_size session option +template +void RunDQMatMulPerTensorWithBlockSize(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t block_size_option) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + + auto* scale_arg = builder.MakeInitializer({}, {10.0f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(std::vector{}, T(1, 0), T(1, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + + // Verify the MatMulNBits node has the expected block_size attribute + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "MatMulNBits") { + auto& attrs = node.GetAttributes(); + auto bs_iter = attrs.find("block_size"); + ASSERT_NE(bs_iter, attrs.end()); + int64_t expected_bs = block_size_option > 0 ? 
block_size_option : 32; // default is 32 + EXPECT_EQ(bs_iter->second.i(), expected_bs); + } + } + }; + + std::function add_session_options_fn = + [block_size_option](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "0"); + std::ignore = sess_opts.config_options.AddConfigEntry( + kOrtSessionOptionsQDQMatMulNBitsBlockSize, + std::to_string(block_size_option).c_str()); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorWithBlockSizeOption) { + // Default block_size (0 -> 32) + RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 0); + // Explicit block_size=16 + RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 16); + // Explicit block_size=64 + RunDQMatMulPerTensorWithBlockSize({12, 64}, {64, 16}, 64); + // Explicit block_size=128 + RunDQMatMulPerTensorWithBlockSize({12, 128}, {128, 16}, 128); +} + +// UINT8 per-tensor DQ -> MatMul -> MatMulNBits +// Tests shapes from real models including small dimensions (N=1, N=8). +template +void RunDQMatMulPerTensorUint8Converted(const std::vector& input1_shape, + const std::vector& weight_shape, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer(weight_shape, uint8_t(0), uint8_t(255)); + auto* dq_output = builder.MakeIntermediate(); + + // Scalar scale (per-tensor) + auto* scale_arg = builder.MakeInitializer({}, {0.05f}); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer({}, {uint8_t(128)}); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 0.01 /*per_sample_tolerance - higher due to blockwise accumulation reordering*/, + 5e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulPerTensorUint8ConvertedToMatMulNBits) { + // Typical shapes + RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 768}, 0); + RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 768}, 0); + // Small N=8 + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); + // N=1 (smallest possible column count) + RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 1}, 0); + 
RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 1}, 0); + // Large N + RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 3072}, 0); +} + +// --------------------------------------------------------------------------- +// DQ -> Gemm tests for MatMulNBits fusion +// --------------------------------------------------------------------------- + +// Input1 +// | DQ (4-bit weight) +// \ / +// Gemm +// | +// output +// Gemm has no bias, equivalent to MatMul. Should fuse to MatMulNBits. +template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedNoBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_NoBias) { + RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ (4-bit weight) bias (float) +// \ / / +// Gemm +// | +// output +// Gemm has a direct (non-DQ) float bias. Should fuse to MatMulNBits with bias at input 5. 
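+// Note: Gemm's C input is broadcastable in general, but the fusion only accepts a 1-D bias of length N,
+// which is the layout MatMulNBits expects for its optional bias input.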
+template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedWithBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + int64_t N = weight_shape[1]; + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + builder.AddNode("Gemm", {input_arg, dq_output, bias_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithBias) { + RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ (4-bit weight) DQ (bias) +// \ / / +// Gemm +// | +// output +// Gemm has a bias from DQ. Weight DQ fused into MatMulNBits, bias DQ stays alive, +// bias DQ output wired to MatMulNBits input 5. 
+template +typename std::enable_if || std::is_same_v, void>::type +RunDQGemmConvertedWithDQBias(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // Weight DQ + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + std::vector scale_shape = weight_shape; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + int64_t N = weight_shape[1]; + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq_output = builder.MakeIntermediate(); + auto* scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs); + } + + // Bias DQ (int8 quantized bias -> float) + auto* bias_quantized = builder.MakeInitializer({N}, std::vector(static_cast(N), 5)); + auto* bias_scale = builder.MakeInitializer({}, std::vector{0.1f}); + auto* bias_zp = builder.MakeInitializer({}, std::vector{0}); + auto* bias_dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {bias_quantized, bias_scale, bias_zp}, {bias_dq_output}); + + builder.AddNode("Gemm", {input_arg, dq_output, bias_dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + // Weight DQ removed, bias DQ stays + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 2e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithDQBias) { + RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); + RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); +} + +// Negative test: DQ -> Gemm with transB=1 should NOT be fused. +TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_TransB) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({12, 37}, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // With transB=1, Gemm transposes B at runtime: weight shape [N,K]=[12,37], transposed to [K,N]=[37,12]. + // DQ weight shape is [12,37] (N=12, K=37 after transpose). 
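+    // The fusion reads K and N from the DQ weight as [K, N] (dims 0 and 1), so a transB=1 layout cannot
+    // be fused without physically transposing the weight; the selector therefore rejects it.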
+    NodeAttributes dq_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(0)), dq_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("block_size", static_cast<int64_t>(16)), dq_attrs);
+    auto* weight_arg = builder.MakeInitializer<Int4x2>({12, 37}, Int4x2(Int4x2::min_val, 0), Int4x2(Int4x2::max_val, 0));
+    auto* scales_arg = builder.MakeInitializer<float>({1, 37}, 8.0f, 12.0f);
+    auto* dq_output = builder.MakeIntermediate();
+    builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs);
+
+    NodeAttributes gemm_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("transB", static_cast<int64_t>(1)), gemm_attrs);
+    builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}, "", &gemm_attrs);
+  };
+
+  auto check_graph = [](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    EXPECT_EQ(op_to_count["Gemm"], 1);
+    EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0);
+  };
+
+  TransformerTester(build_test_case,
+                    check_graph,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    21 /*opset_version*/,
+                    1e-5, 2e-5);
+}
+
+// Negative test: DQ -> Gemm with alpha != 1.0 should NOT be fused.
+TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_Alpha) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<float>({12, 37}, -100.0f, 100.0f);
+    auto* output_arg = builder.MakeOutput();
+
+    NodeAttributes dq_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(0)), dq_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("block_size", static_cast<int64_t>(16)), dq_attrs);
+    auto* weight_arg = builder.MakeInitializer<Int4x2>({37, 12}, Int4x2(Int4x2::min_val, 0), Int4x2(Int4x2::max_val, 0));
+    auto* scales_arg = builder.MakeInitializer<float>({3, 12}, 8.0f, 12.0f);
+    auto* dq_output = builder.MakeIntermediate();
+    builder.AddNode("DequantizeLinear", {weight_arg, scales_arg}, {dq_output}, "", &dq_attrs);
+
+    NodeAttributes gemm_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("alpha", 2.0f), gemm_attrs);
+    builder.AddNode("Gemm", {input_arg, dq_output}, {output_arg}, "", &gemm_attrs);
+  };
+
+  auto check_graph = [](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    EXPECT_EQ(op_to_count["Gemm"], 1);
+    EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0);
+  };
+
+  TransformerTester(build_test_case,
+                    check_graph,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    21 /*opset_version*/,
+                    1e-5, 2e-5);
+}
+
 #endif  // !defined(DISABLE_CONTRIB_OPS)
 
 }  // namespace test

From bc80857dc72826212db37b55411d9e1ec10fac22 Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Thu, 19 Mar 2026 19:32:39 +0000
Subject: [PATCH 2/7] Fix scalar zp MakeInitializer assertion failure in per-tensor tests

MakeInitializer<T>(shape, T(1, 0), T(1, 0)) calls Uniform with min == max, which
creates uniform_int_distribution(1, 0) (since Uniform uses [min, max)). This
triggers an assertion failure on both MSVC and GCC.
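For reference, a minimal standalone sketch of the violated precondition
(illustrative only, not part of this change; the int element type and the seed
are placeholders for whatever ModelTestBuilder actually uses):

    #include <random>

    int main() {
      // MakeInitializer(shape, min, max) draws values from the half-open
      // range [min, max). With min == max that range is empty, so the helper
      // ends up constructing something equivalent to the line below, which
      // violates the a <= b requirement of std::uniform_int_distribution and
      // trips the debug assertions in MSVC's STL and in libstdc++.
      std::uniform_int_distribution<int> dist(/*a*/ 1, /*b*/ 0);
      std::mt19937 gen{0};
      return dist(gen);
    }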
Use explicit data overload instead: MakeInitializer({}, {T(1,0)}) --- .../test/optimizer/qdq_matmulnbits_transformer_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index b2344858dd1a1..b71d04de9a2d2 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -802,7 +802,7 @@ void RunDQMatMulPerTensorConverted(const std::vector& input1_shape, // Scalar scale (per-tensor) auto* scale_arg = builder.MakeInitializer({}, {10.0f}); if constexpr (use_zp) { - auto* zp_arg = builder.MakeInitializer(std::vector{}, T(1, 0), T(1, 0)); + auto* zp_arg = builder.MakeInitializer({}, std::vector{T(1, 0)}); builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); } else { builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); @@ -997,7 +997,7 @@ void RunDQMatMulPerTensorWithBlockSize(const std::vector& input1_shape, auto* scale_arg = builder.MakeInitializer({}, {10.0f}); if constexpr (use_zp) { - auto* zp_arg = builder.MakeInitializer(std::vector{}, T(1, 0), T(1, 0)); + auto* zp_arg = builder.MakeInitializer({}, std::vector{T(1, 0)}); builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}); } else { builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}); From 0165da20b0c3242f16fa6fd77316d1f47bba018b Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 19 Mar 2026 22:06:49 +0000 Subject: [PATCH 3/7] Reduce test matrix sizes to lower memory pressure under ASan Shrink per-tensor test dimensions from 768x768/768x3072 to 96x96/96x384. The same code paths are exercised regardless of matrix size. 
--- .../qdq_matmulnbits_transformer_test.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index b71d04de9a2d2..6347b27f6b2c5 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -852,8 +852,8 @@ TEST(QDQTransformerTests, DQMatMulPerTensorConvertedToMatMulNBits) { RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); // N=1 (edge case: single column) - RunDQMatMulPerTensorConverted({12, 768}, {768, 1}, 0); - RunDQMatMulPerTensorConverted({12, 768}, {768, 1}, 0); + RunDQMatMulPerTensorConverted({12, 96}, {96, 1}, 0); + RunDQMatMulPerTensorConverted({12, 96}, {96, 1}, 0); } // Per-channel (axis=1) DQ -> MatMul conversion to MatMulNBits @@ -1108,15 +1108,15 @@ void RunDQMatMulPerTensorUint8Converted(const std::vector& input1_shape TEST(QDQTransformerTests, DQMatMulPerTensorUint8ConvertedToMatMulNBits) { // Typical shapes - RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 768}, 0); - RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 768}, 0); + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 96}, 0); + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 96}, 0); // Small N=8 RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); // N=1 (smallest possible column count) - RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 1}, 0); - RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 1}, 0); - // Large N - RunDQMatMulPerTensorUint8Converted({12, 768}, {768, 3072}, 0); + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 1}, 0); + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 1}, 0); + // Larger N + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 384}, 0); } // --------------------------------------------------------------------------- From ea2c6e45374d3febfa927a22517a937905d78541 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 19 Mar 2026 23:16:48 +0000 Subject: [PATCH 4/7] Trim redundant test combinations to reduce ASan memory pressure Remove symmetric type/zp combos where one signed+zp and one unsigned-no-zp call sufficiently covers the code paths. Reduces Run* calls from 172 to 153 (38 fewer inference sessions). 
--- .../qdq_matmulnbits_transformer_test.cc | 30 ++----------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 6347b27f6b2c5..1e80b3c6671b5 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -840,20 +840,15 @@ void RunDQMatMulPerTensorConverted(const std::vector& input1_shape, } TEST(QDQTransformerTests, DQMatMulPerTensorConvertedToMatMulNBits) { - // Per-tensor int4/uint4 with and without zero-point + // Per-tensor: signed with zp, unsigned without zp RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); - RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); - RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); // With accuracy_level=1 RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 1); - RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 1); // K not divisible by default block_size (32) - RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); // N=1 (edge case: single column) RunDQMatMulPerTensorConverted({12, 96}, {96, 1}, 0); - RunDQMatMulPerTensorConverted({12, 96}, {96, 1}, 0); } // Per-channel (axis=1) DQ -> MatMul conversion to MatMulNBits @@ -921,16 +916,11 @@ void RunDQMatMulPerChannelConverted(const std::vector& input1_shape, } TEST(QDQTransformerTests, DQMatMulPerChannelConvertedToMatMulNBits) { - // Per-channel int4/uint4 with and without zero-point + // Per-channel: signed with zp, unsigned without zp RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); - RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); - RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); - // With accuracy_level=1 - RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 1); // K not divisible by default block_size (32) RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); - RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); } // Negative test: per-axis axis=0 with 1D scale should NOT fuse @@ -980,7 +970,6 @@ void RunDQMatMulPerAxisAxis0NotConverted(const std::vector& input1_shap TEST(QDQTransformerTests, DQMatMulPerAxisAxis0NotConvertedToMatMulNBits) { RunDQMatMulPerAxisAxis0NotConverted({12, 32}, {32, 16}); - RunDQMatMulPerAxisAxis0NotConverted({12, 32}, {32, 16}); } // Per-tensor DQ -> MatMul with configurable block_size session option @@ -1107,16 +1096,9 @@ void RunDQMatMulPerTensorUint8Converted(const std::vector& input1_shape } TEST(QDQTransformerTests, DQMatMulPerTensorUint8ConvertedToMatMulNBits) { - // Typical shapes RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 96}, 0); - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 96}, 0); - // Small N=8 - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); - // N=1 (smallest possible column count) + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 1}, 0); - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 1}, 0); - // Larger N - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 384}, 0); } // --------------------------------------------------------------------------- @@ -1190,8 +1172,6 @@ RunDQGemmConvertedNoBias(const std::vector& input1_shape, TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_NoBias) { RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); - 
RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); } @@ -1264,8 +1244,6 @@ RunDQGemmConvertedWithBias(const std::vector& input1_shape, TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithBias) { RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); } @@ -1346,8 +1324,6 @@ RunDQGemmConvertedWithDQBias(const std::vector& input1_shape, TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithDQBias) { RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); } From 11b46d0c5e9f0cb068d55914e5d6a2b4f76ec5cf Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Fri, 20 Mar 2026 01:39:06 +0000 Subject: [PATCH 5/7] Reduce test matrix to fix ASan OOM in onnxruntime_test_all Trim redundant type/zp/accuracy_level permutations in negative tests (NonConstDQ, FirstDQInput, ShapeMismatch) where the rejection logic doesn't depend on these parameters. Also trim new per-tensor, per-channel, Gemm, and uint8 tests to representative combinations. Session-creating invocations: 153 -> 63 (59% reduction). All 34 DQMatMul/DQGemm tests still pass. --- .../qdq_matmulnbits_transformer_test.cc | 111 +----------------- 1 file changed, 5 insertions(+), 106 deletions(-) diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index 1e80b3c6671b5..03005e3a07386 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -114,43 +114,15 @@ RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { // DQ contrib op schema is not updated to support blocked quantization + // Rejection doesn't depend on type/zp/accuracy_level — keep representative combos only. 
RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); } // Input2 @@ -224,42 +196,13 @@ RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - 
RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1, DefaultCudaExecutionProvider()); } // Input1 @@ -353,52 +296,27 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { - // DQ contrib op schema is not updated to support blocked quantization + // One representative type combo per rejection scenario (type doesn't affect rejection logic). 
// block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch_Cuda) { - // DQ contrib op schema is not updated to support blocked quantization + // One representative type combo per rejection scenario. // block size too small RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0, DefaultCudaExecutionProvider()); // block size not 2's power - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); - ; - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0, DefaultCudaExecutionProvider()); // not axis 0 - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0, DefaultCudaExecutionProvider()); // not rank 2 - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); - RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); } // Input1 @@ -840,15 +758,9 @@ void RunDQMatMulPerTensorConverted(const std::vector& input1_shape, } TEST(QDQTransformerTests, DQMatMulPerTensorConvertedToMatMulNBits) { - // Per-tensor: signed with zp, unsigned without zp + // 
Per-tensor: cover both types and a non-divisible K case. RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); - RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 0); - // With accuracy_level=1 - RunDQMatMulPerTensorConverted({12, 32}, {32, 16}, 1); - // K not divisible by default block_size (32) RunDQMatMulPerTensorConverted({12, 37}, {37, 16}, 0); - // N=1 (edge case: single column) - RunDQMatMulPerTensorConverted({12, 96}, {96, 1}, 0); } // Per-channel (axis=1) DQ -> MatMul conversion to MatMulNBits @@ -916,10 +828,6 @@ void RunDQMatMulPerChannelConverted(const std::vector& input1_shape, } TEST(QDQTransformerTests, DQMatMulPerChannelConvertedToMatMulNBits) { - // Per-channel: signed with zp, unsigned without zp - RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); - RunDQMatMulPerChannelConverted({12, 32}, {32, 16}, 0); - // K not divisible by default block_size (32) RunDQMatMulPerChannelConverted({12, 37}, {37, 16}, 0); } @@ -1036,10 +944,6 @@ TEST(QDQTransformerTests, DQMatMulPerTensorWithBlockSizeOption) { RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 0); // Explicit block_size=16 RunDQMatMulPerTensorWithBlockSize({12, 32}, {32, 16}, 16); - // Explicit block_size=64 - RunDQMatMulPerTensorWithBlockSize({12, 64}, {64, 16}, 64); - // Explicit block_size=128 - RunDQMatMulPerTensorWithBlockSize({12, 128}, {128, 16}, 128); } // UINT8 per-tensor DQ -> MatMul -> MatMulNBits @@ -1096,9 +1000,7 @@ void RunDQMatMulPerTensorUint8Converted(const std::vector& input1_shape } TEST(QDQTransformerTests, DQMatMulPerTensorUint8ConvertedToMatMulNBits) { - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 96}, 0); - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); - RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 1}, 0); + RunDQMatMulPerTensorUint8Converted({12, 96}, {96, 8}, 0); } // --------------------------------------------------------------------------- @@ -1172,7 +1074,6 @@ RunDQGemmConvertedNoBias(const std::vector& input1_shape, TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_NoBias) { RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedNoBias({12, 37}, {37, 12}, 0, 16, 0); } // Input1 @@ -1244,7 +1145,6 @@ RunDQGemmConvertedWithBias(const std::vector& input1_shape, TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithBias) { RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedWithBias({12, 37}, {37, 12}, 0, 16, 0); } // Input1 @@ -1324,7 +1224,6 @@ RunDQGemmConvertedWithDQBias(const std::vector& input1_shape, TEST(QDQTransformerTests, DQGemmConvertedToMatMulNBits_WithDQBias) { RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); - RunDQGemmConvertedWithDQBias({12, 37}, {37, 12}, 0, 16, 0); } // Negative test: DQ -> Gemm with transB=1 should NOT be fused. From 61491e56cc7732c21d3aba5eefe2850f2bec1f91 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Mon, 23 Mar 2026 21:57:44 +0000 Subject: [PATCH 6/7] Use weight_tensor_proto dims instead of weight_arg->Shape() Address Copilot review: derive K and N from the already-loaded weight_tensor_proto->dims() rather than weight_arg->Shape(), which avoids an unnecessary NodeArg::Shape() dereference. Also adds an explicit rank >= 2 check on the tensor proto. 
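For context, a short sketch of the access pattern this moves toward
(illustrative only, not code from this patch; it assumes ORT's common::Status
and the ORT_RETURN_IF_NOT macro). A constant initializer's TensorProto always
carries its dims, while NodeArg::Shape() returns a pointer that may still be
null, so the proto is the safer source for K and N:

    // Hypothetical helper for illustration; the name is not from the patch.
    onnxruntime::common::Status GetWeightKN(const ONNX_NAMESPACE::TensorProto& weight_proto,
                                            int64_t& K, int64_t& N) {
      ORT_RETURN_IF_NOT(weight_proto.dims_size() >= 2,
                        "Weight initializer must be at least rank 2.");
      K = weight_proto.dims(0);  // rows of B, i.e. the reduction dimension
      N = weight_proto.dims(1);  // columns of B
      return onnxruntime::common::Status::OK();
    }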
--- .../qdq_transformer/selectors_actions/qdq_actions.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 05d7dde81b2e6..f320d94f30a1e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -140,8 +140,10 @@ Status TransposeDQWeightsForMatMulNBits( graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); } - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); + ORT_RETURN_IF_NOT(weight_tensor_proto->dims_size() >= 2, + "Weight tensor for node ", dq_node.Name(), " must be at least 2D."); + auto K = weight_tensor_proto->dims(0); + auto N = weight_tensor_proto->dims(1); auto block_size = effective_block_size; int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); auto bits = DQWeightBits(dt_weight); From 1a51c7a27e26794cb5b70cb56bcc8022bfb45986 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Mon, 23 Mar 2026 23:00:29 +0000 Subject: [PATCH 7/7] null guard --- .../optimizer/qdq_transformer/selectors_actions/qdq_actions.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index f320d94f30a1e..b9d7e898157bd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -621,6 +621,8 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) const auto* dq_node = runtime_state.selected_nodes.Input(0); const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + ORT_ENFORCE(weight_shape != nullptr && weight_shape->dim_size() >= 2, + "Weight shape unavailable for DQ node ", dq_node->Name()); utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes);