diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 1ea147a0079cc..f73a94ae686d0 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -385,6 +385,11 @@ static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm"; // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; +// Enable the DQ->MatMulNBits fusion graph transformer. +// "0": disabled (default). "1": enabled. +// This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered. +static const char* const kOrtSessionOptionsEnableDQMatMulNBitsFusion = "session.enable_dq_matmulnbits_fusion"; + // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME // Meant to be used with SetEpDynamicOptions // Specify the type of workload for this session. diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc new file mode 100644 index 0000000000000..f9ae13808cf2c --- /dev/null +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc @@ -0,0 +1,848 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "core/optimizer/dq_matmulnbits_fusion.h"
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include "core/common/common.h"
+#include "core/common/safeint.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/constants.h"
+#include "core/graph/graph_utils.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+
+#include <cstring>
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+namespace onnxruntime {
+
+namespace {
+
+// ---------------------------------------------------------------------------
+// Utility helpers
+// ---------------------------------------------------------------------------
+
+bool IsUniformPackedUint4Value(const Initializer& init, uint8_t expected_nibble) {
+  if (init.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
+    return false;
+  }
+
+  const size_t values_count = static_cast<size_t>(init.size());
+  if (values_count == 0) {
+    return false;
+  }
+
+  const auto packed = init.DataAsByteSpan();
+  const uint8_t expected = static_cast<uint8_t>(expected_nibble & 0x0F);
+  for (size_t i = 0; i < values_count; ++i) {
+    const uint8_t byte = packed[i / 2];
+    const uint8_t value = (i % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F);
+    if (value != expected) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool HasRank2Shape(const ONNX_NAMESPACE::TensorProto& tp, int64_t dim0, int64_t dim1) {
+  return tp.dims_size() == 2 && tp.dims(0) == dim0 && tp.dims(1) == dim1;
+}
+
+uint8_t GetPackedUint4Element(const uint8_t* packed, size_t index, size_t num_elements) {
+  ORT_ENFORCE(index < num_elements, "GetPackedUint4Element: index ", index,
+              " out of bounds (num_elements=", num_elements, ")");
+  const uint8_t packed_byte = packed[index / 2];
+  return (index % 2 == 0) ? static_cast<uint8_t>(packed_byte & 0x0F)
+                          : static_cast<uint8_t>((packed_byte >> 4) & 0x0F);
+}
+
+void PackUint4Rows(const Initializer& src, int64_t rows, int64_t cols, uint8_t* dst) {
+  const int64_t row_bytes = (cols + 1) / 2;
+  const size_t dst_bytes = SafeInt<size_t>(rows) * row_bytes;
+  const size_t total_elements = SafeInt<size_t>(rows) * cols;
+  memset(dst, 0, dst_bytes);
+
+  const auto src_packed = src.DataAsByteSpan();
+  for (int64_t r = 0; r < rows; ++r) {
+    for (int64_t c = 0; c < cols; ++c) {
+      const size_t src_index = SafeInt<size_t>(r) * cols + c;
+      const uint8_t value = GetPackedUint4Element(src_packed.data(), src_index, total_elements);
+
+      const size_t dst_index = SafeInt<size_t>(r) * row_bytes + c / 2;
+      if ((c & 1) == 0) {
+        dst[dst_index] = value;
+      } else {
+        dst[dst_index] = static_cast<uint8_t>(dst[dst_index] | (value << 4));
+      }
+    }
+  }
+}
+
+// Transpose and pack UINT4 weights from DQ axis=0 layout [K, N] to MatMulNBits layout [N, k_blocks, blob_size].
+// Source: row-major UINT4 with quantization along K (axis=0), shape [K, N].
+// The nibble ordering follows ONNX UINT4 convention: even indices in the low nibble,
+// odd indices in the high nibble of each byte.
+// Dest: UINT8 [N, k_blocks, block_size/2] where each byte packs two 4-bit weights.
+void TransposePackWeightsAxis0(
+    const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size,
+    uint8_t* dst) {
+  const int64_t k_blocks = (K + block_size - 1) / block_size;
+  const int64_t blob_size = block_size / 2;
+  const size_t dst_bytes = SafeInt<size_t>(N) * k_blocks * blob_size;
+  const size_t total_elements = SafeInt<size_t>(K) * N;
+  memset(dst, 0, dst_bytes);
+
+  for (int64_t n = 0; n < N; ++n) {
+    for (int64_t k = 0; k < K; ++k) {
+      const size_t src_index = SafeInt<size_t>(k) * N + n;
+      const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements);
+
+      const int64_t kb = k / block_size;
+      const int64_t off = k % block_size;
+      const size_t dst_byte = SafeInt<size_t>(n) * k_blocks * blob_size + kb * blob_size + off / 2;
+      if (off % 2 == 0) {
+        dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0xF0) | val);
+      } else {
+        dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0x0F) | (val << 4));
+      }
+    }
+  }
+}
+
+// Transpose and pack UINT4 zero points from DQ axis=0 layout [k_blocks, N] to
+// MatMulNBits layout UINT8 [N, ceil(k_blocks/2)].
+void TransposePackZPAxis0(
+    const uint8_t* src_packed, int64_t k_blocks, int64_t N,
+    uint8_t* dst) {
+  const int64_t zp_bytes_per_n = (k_blocks + 1) / 2;
+  const size_t dst_bytes = SafeInt<size_t>(N) * zp_bytes_per_n;
+  const size_t total_elements = SafeInt<size_t>(k_blocks) * N;
+  memset(dst, 0, dst_bytes);
+
+  for (int64_t n = 0; n < N; ++n) {
+    for (int64_t kb = 0; kb < k_blocks; ++kb) {
+      const size_t src_index = SafeInt<size_t>(kb) * N + n;
+      const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements);
+
+      const size_t dst_byte = SafeInt<size_t>(n) * zp_bytes_per_n + kb / 2;
+      if (kb % 2 == 0) {
+        dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0xF0) | val);
+      } else {
+        dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0x0F) | (val << 4));
+      }
+    }
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Match structs
+// ---------------------------------------------------------------------------
+
+struct FusionMatch {
+  NodeIndex matmul_idx;
+  std::optional<NodeIndex> cast_idx;
+  NodeIndex transpose_idx;
+  NodeIndex reshape_idx;
+  NodeIndex dq_idx;
+};
+
+struct DirectDQMatch {
+  NodeIndex matmul_idx;
+  NodeIndex dq_idx;
+};
+
+// ---------------------------------------------------------------------------
+// Shared Gemm validation (alpha=1, beta=1, transA=0, transB=0, bias 1-D [N])
+// ---------------------------------------------------------------------------
+
+bool ValidateGemmForFusion(const Node& gemm_node, int64_t N) {
+  if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm_node, "alpha");
+      alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f)
+    return false;
+  if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm_node, "beta");
+      beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f)
+    return false;
+  if (const auto* trans_a = graph_utils::GetNodeAttribute(gemm_node, "transA");
+      trans_a && trans_a->i() != 0)
+    return false;
+  if (const auto* trans_b = graph_utils::GetNodeAttribute(gemm_node, "transB");
+      trans_b && trans_b->i() != 0)
+    return false;
+
+  const auto& inputs = gemm_node.InputDefs();
+  if (inputs.size() > 2 && inputs[2] && inputs[2]->Exists()) {
+    const auto* bias_shape = inputs[2]->Shape();
+    if (!bias_shape || bias_shape->dim_size() != 1 ||
+        !utils::HasDimValue(bias_shape->dim(0)) ||
+        bias_shape->dim(0).dim_value() != N)
+      return false;
+  }
+  return true;
+}
+
+// ---------------------------------------------------------------------------
+// Pattern 1 matching: DQ -> Reshape -> Transpose -> [Cast] -> MatMul/Gemm
+// ---------------------------------------------------------------------------
+
+std::vector<FusionMatch> CollectReshapeTransposeMatches(
+    Graph& graph,
+    const std::vector<NodeIndex>& node_topology_list,
+    const logging::Logger& logger) {
+  std::vector<FusionMatch> matches;
+
+  for (auto node_index : node_topology_list) {
+    auto* node = graph.GetNode(node_index);
+    if (!node) continue;
+
+    if (node->OpType() != "MatMul" && node->OpType() != "Gemm") continue;
+
+    const auto& mm_inputs = node->InputDefs();
+    if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue;
+
+    const Node* cast_node = nullptr;
+    const Node* transpose_node = graph.GetProducerNode(mm_inputs[1]->Name());
+    if (transpose_node && transpose_node->OpType() == "Cast") {
+      cast_node = transpose_node;
+      if (cast_node->GetOutputEdgesCount() != 1) continue;
+      const auto& cast_inputs = cast_node->InputDefs();
+      if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue;
+      transpose_node = graph.GetProducerNode(cast_inputs[0]->Name());
+    }
+
+    if (!transpose_node || transpose_node->OpType() != "Transpose") continue;
+    if (transpose_node->GetOutputEdgesCount() != 1) continue;
+
+    const auto& tp_inputs = transpose_node->InputDefs();
+    if (tp_inputs.empty() || !tp_inputs[0] || !tp_inputs[0]->Exists()) continue;
+    const Node* reshape_node = graph.GetProducerNode(tp_inputs[0]->Name());
+    if (!reshape_node || reshape_node->OpType() != "Reshape") continue;
+    if
(reshape_node->GetOutputEdgesCount() != 1) continue; + + const auto& reshape_inputs = reshape_node->InputDefs(); + if (reshape_inputs.empty() || !reshape_inputs[0] || !reshape_inputs[0]->Exists()) continue; + const Node* dq_node = graph.GetProducerNode(reshape_inputs[0]->Name()); + if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue; + if (dq_node->GetOutputEdgesCount() != 1) continue; + + const auto& dq_attrs = dq_node->GetAttributes(); + { + auto it = dq_attrs.find("axis"); + if (it == dq_attrs.end() || it->second.i() != 2) continue; + } + int64_t block_size = 0; + { + auto it = dq_attrs.find("block_size"); + if (it == dq_attrs.end()) continue; + block_size = it->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) continue; + } + + const auto* weight_arg = dq_node->InputDefs()[0]; + if (!weight_arg || !weight_arg->Exists()) continue; + const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); + if (!weight_const_tp) continue; + if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (weight_const_tp->dims_size() != 3) continue; + const int64_t N = weight_const_tp->dims(0); + const int64_t blocks = weight_const_tp->dims(1); + const int64_t bs_dim = weight_const_tp->dims(2); + if (N <= 0 || blocks <= 0 || bs_dim <= 0) continue; + if (bs_dim != block_size) continue; + const int64_t K = SafeInt(blocks) * bs_dim; + + const auto* scale_arg = dq_node->InputDefs()[1]; + if (!scale_arg || !scale_arg->Exists()) continue; + const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); + if (!scale_const_tp) continue; + int32_t dt_scale = scale_const_tp->data_type(); + if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + + const auto* a_arg = mm_inputs[0]; + if (!a_arg || !a_arg->TypeAsProto()) continue; + int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); + if 
(dt_a != dt_scale) continue; + + const auto* reshape_shape_arg = + reshape_node->InputDefs().size() > 1 ? reshape_node->InputDefs()[1] : nullptr; + if (!reshape_shape_arg || !reshape_shape_arg->Exists()) continue; + const auto* reshape_shape_tp = graph.GetConstantInitializer(reshape_shape_arg->Name(), true); + if (!reshape_shape_tp) continue; + + Initializer reshape_shape_init(graph, *reshape_shape_tp, graph.ModelPath()); + if (reshape_shape_init.size() != 2) continue; + + int64_t reshape_dim0 = 0; + int64_t reshape_dim1 = 0; + if (reshape_shape_init.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + const auto* shape_data = reshape_shape_init.data(); + reshape_dim0 = shape_data[0]; + reshape_dim1 = shape_data[1]; + } else if (reshape_shape_init.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) { + const auto* shape_data = reshape_shape_init.data(); + reshape_dim0 = shape_data[0]; + reshape_dim1 = shape_data[1]; + } else { + continue; + } + + auto resolve_reshape_dim = [](int64_t dim, int64_t expected) -> std::optional { + if (dim == expected || dim == 0 || dim == -1) { + return expected; + } + return std::nullopt; + }; + const auto resolved_reshape_dim0 = resolve_reshape_dim(reshape_dim0, N); + const auto resolved_reshape_dim1 = resolve_reshape_dim(reshape_dim1, K); + if (!resolved_reshape_dim0 || !resolved_reshape_dim1 || + *resolved_reshape_dim0 != N || *resolved_reshape_dim1 != K) { + continue; + } + + if (const auto* perm_attr = graph_utils::GetNodeAttribute(*transpose_node, "perm")) { + if (perm_attr->ints_size() != 2 || perm_attr->ints(0) != 1 || perm_attr->ints(1) != 0) { + continue; + } + } + + if (const auto* b_shape = mm_inputs[1]->Shape(); b_shape && b_shape->dim_size() == 2 && + utils::HasDimValue(b_shape->dim(0)) && utils::HasDimValue(b_shape->dim(1)) && + (b_shape->dim(0).dim_value() != K || b_shape->dim(1).dim_value() != N)) { + continue; + } + + if (const auto* a_shape = mm_inputs[0] ? 
mm_inputs[0]->Shape() : nullptr; + a_shape && a_shape->dim_size() >= 1) { + const int last_a_dim_idx = a_shape->dim_size() - 1; + if (utils::HasDimValue(a_shape->dim(last_a_dim_idx)) && + a_shape->dim(last_a_dim_idx).dim_value() != K) { + continue; + } + } + + const auto* y_shape = node->OutputDefs().empty() ? nullptr : node->OutputDefs()[0]->Shape(); + if (y_shape && y_shape->dim_size() >= 1) { + const int last_y_dim_idx = y_shape->dim_size() - 1; + if (utils::HasDimValue(y_shape->dim(last_y_dim_idx)) && + y_shape->dim(last_y_dim_idx).dim_value() != N) { + continue; + } + } + + if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; + + if (cast_node) { + const auto* cast_in = cast_node->InputDefs().empty() ? nullptr : cast_node->InputDefs()[0]; + const auto* cast_out = cast_node->OutputDefs().empty() ? nullptr : cast_node->OutputDefs()[0]; + if (!cast_in || !cast_out || !cast_in->TypeAsProto() || !cast_out->TypeAsProto()) continue; + if (cast_in->TypeAsProto()->tensor_type().elem_type() != + cast_out->TypeAsProto()->tensor_type().elem_type()) { + continue; + } + } + + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + if (has_zp) { + const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); + if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + } + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched pattern at MatMul node '" + << node->Name() << "'"; + + matches.push_back({node->Index(), + cast_node ? 
std::optional<NodeIndex>(cast_node->Index()) : std::nullopt,
+                       transpose_node->Index(),
+                       reshape_node->Index(), dq_node->Index()});
+  }
+
+  return matches;
+}
+
+// ---------------------------------------------------------------------------
+// Pattern 2 matching: direct DQ(axis=0, 2D UINT4) -> MatMul/Gemm
+// ---------------------------------------------------------------------------
+
+std::vector<DirectDQMatch> CollectDirectDQMatches(
+    Graph& graph,
+    const std::vector<NodeIndex>& node_topology_list,
+    const std::unordered_set<NodeIndex>& skip_indices,
+    const logging::Logger& logger) {
+  std::vector<DirectDQMatch> direct_matches;
+
+  for (auto node_index : node_topology_list) {
+    auto* node = graph.GetNode(node_index);
+    if (!node) continue;
+
+    if (node->OpType() != "MatMul" && node->OpType() != "Gemm") continue;
+    if (skip_indices.count(node->Index())) continue;
+
+    const auto& mm_inputs = node->InputDefs();
+    if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue;
+
+    const Node* dq_node = graph.GetProducerNode(mm_inputs[1]->Name());
+    if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue;
+    if (dq_node->GetOutputEdgesCount() != 1) continue;
+
+    const auto& dq_attrs = dq_node->GetAttributes();
+    {
+      auto it = dq_attrs.find("axis");
+      if (it == dq_attrs.end() || it->second.i() != 0) continue;
+    }
+    int64_t block_size = 0;
+    {
+      auto it = dq_attrs.find("block_size");
+      if (it == dq_attrs.end()) continue;
+      block_size = it->second.i();
+      if (block_size < 16 || ((block_size - 1) & block_size)) continue;
+    }
+
+    const auto* weight_arg = dq_node->InputDefs()[0];
+    if (!weight_arg || !weight_arg->Exists()) continue;
+    const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true);
+    if (!weight_const_tp) continue;
+    if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+    if (weight_const_tp->dims_size() != 2) continue;
+    const int64_t K = weight_const_tp->dims(0);
+    const int64_t N = weight_const_tp->dims(1);
+    if (K <= 0 || N <= 0
|| K % block_size != 0) continue; + const int64_t k_blocks = K / block_size; + + const auto* scale_arg = dq_node->InputDefs()[1]; + if (!scale_arg || !scale_arg->Exists()) continue; + const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); + if (!scale_const_tp) continue; + int32_t dt_scale = scale_const_tp->data_type(); + if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + if (!HasRank2Shape(*scale_const_tp, k_blocks, N)) continue; + + const auto* a_arg = mm_inputs[0]; + if (!a_arg || !a_arg->TypeAsProto()) continue; + int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_a != dt_scale) continue; + + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + if (has_zp) { + const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); + if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (!HasRank2Shape(*zp_const_tp, k_blocks, N)) continue; + } + + if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched direct DQ->MatMul pattern at node '" + << node->Name() << "' (K=" << K << ", N=" << N << ", block_size=" << block_size << ")"; + direct_matches.push_back({node->Index(), dq_node->Index()}); + } + + return direct_matches; +} + +// --------------------------------------------------------------------------- +// Pattern 1 rewriting: DQ+Reshape+Transpose+[Cast]+MatMul/Gemm -> MatMulNBits +// --------------------------------------------------------------------------- + +void ApplyReshapeTransposeFusions( + Graph& graph, + const std::vector& matches, + int64_t accuracy_level, + bool& modified, + const logging::Logger& logger) { + for (const auto& match : matches) { + const Node* mm_node = graph.GetNode(match.matmul_idx); + 
const Node* cast_node = match.cast_idx ? graph.GetNode(*match.cast_idx) : nullptr; + const Node* tp_node = graph.GetNode(match.transpose_idx); + const Node* dq_node = graph.GetNode(match.dq_idx); + const Node* reshape_node = graph.GetNode(match.reshape_idx); + if (!mm_node || !tp_node || !dq_node || !reshape_node || + (match.cast_idx && !cast_node)) { + continue; + } + + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + + const auto& dq_attrs = dq_node->GetAttributes(); + const int64_t block_size = dq_attrs.at("block_size").i(); + + const ONNX_NAMESPACE::TensorProto* weight_tp = nullptr; + if (!graph.GetInitializedTensor(weight_arg->Name(), weight_tp) || !weight_tp) continue; + const ONNX_NAMESPACE::TensorProto* scale_tp = nullptr; + if (!graph.GetInitializedTensor(scale_arg->Name(), scale_tp) || !scale_tp) continue; + const ONNX_NAMESPACE::TensorProto* zp_tp = nullptr; + if (has_zp) { + if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; + } + + if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || + weight_tp->dims_size() != 3) { + continue; + } + + const int64_t N = weight_tp->dims(0); + const int64_t quant_num = weight_tp->dims(1); + const int64_t bs_dim = weight_tp->dims(2); + if (N <= 0 || quant_num <= 0 || bs_dim <= 0 || bs_dim != block_size) continue; + const int64_t K = SafeInt(quant_num) * bs_dim; + const int64_t blob_bytes = (block_size + 1) / 2; + + Initializer weight_src(graph, *weight_tp, graph.ModelPath()); + Initializer scale_src(graph, *scale_tp, graph.ModelPath()); + if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum( + 
ONNX_NAMESPACE::TensorProto_DataType_UINT8) + ->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum( + scale_src.data_type()) + ->GetElementType(); + + auto cpu_allocator = CPUAllocator::DefaultInstance(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb"); + auto weight_dst = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb"); + const int64_t scale_size = (TensorShape{N, quant_num}).Size(); + if (scale_src.size() != static_cast(scale_size)) continue; + auto scale_dst = Tensor(scale_type, TensorShape{scale_size}, cpu_allocator); + + std::string zp_dst_name; + std::optional zp_dst; + const int64_t zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); + + bool elide_default_uint4_zp8_input = false; + std::optional zp_src; + + const auto weight_bytes = weight_src.DataAsByteSpan(); + if (weight_bytes.size() != static_cast(weight_dst.SizeInBytes())) continue; + memcpy(weight_dst.MutableDataRaw(), weight_bytes.data(), weight_bytes.size()); + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + memcpy(scale_dst.MutableData(), scale_src.data(), + static_cast(scale_size) * sizeof(float)); + } else { + memcpy(scale_dst.MutableData(), scale_src.data(), + static_cast(scale_size) * sizeof(MLFloat16)); + } + + if (zp_tp) { + zp_src.emplace(graph, *zp_tp, graph.ModelPath()); + if (zp_src->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (zp_src->size() != static_cast(N * quant_num)) continue; + + const bool is_default_uint4_8 = + IsUniformPackedUint4Value(*zp_src, /*expected_nibble*/ 8); + if (is_default_uint4_8) { + elide_default_uint4_zp8_input = true; + } else { + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + PackUint4Rows(*zp_src, N, quant_num, zp_dst->MutableData()); + } + } else 
{ + // DequantizeLinear default zero-point for uint4 is 0, while MatMulNBits + // default is 8. Emit explicit zeros to preserve semantics. + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_zp_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); + } + + auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); + std::optional zp_mnb_tp; + if (zp_dst && !elide_default_uint4_zp8_input) { + zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + } + + NodeAttributes mnb_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); + + std::vector mnb_inputs; + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); + if (zp_mnb_tp) { + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); + } + + // MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zero_points(opt), 4:g_idx(opt), 5:bias(opt) + bool fused_with_bias = false; + if (mm_node->OpType() == "Gemm" && + mm_node->InputDefs().size() > 2 && + mm_node->InputDefs()[2] && + mm_node->InputDefs()[2]->Exists()) { + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (mnb_inputs.size() < 5) { + mnb_inputs.push_back(&empty_arg); + } + 
mnb_inputs.push_back(const_cast(mm_node->InputDefs()[2])); + fused_with_bias = true; + } + + std::vector mnb_outputs; + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + + auto& mnb_node = graph.AddNode( + graph.GenerateNodeName("DQFusedMatMulNBits"), + "MatMulNBits", + "Fused from DQ+Reshape+Transpose+MatMul", + mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); + mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); + graph.RemoveNode(match.matmul_idx); + + if (match.cast_idx && graph.GetNode(*match.cast_idx)) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.cast_idx)); + graph.RemoveNode(*match.cast_idx); + } + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.transpose_idx)); + graph.RemoveNode(match.transpose_idx); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.reshape_idx)); + graph.RemoveNode(match.reshape_idx); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx)); + graph.RemoveNode(match.dq_idx); + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused DQ+Reshape+Transpose" + << (match.cast_idx ? "+Cast" : "") + << "+MatMul/Gemm -> MatMulNBits" + << (fused_with_bias ? " (bias preserved)" : "") + << (elide_default_uint4_zp8_input ? 
" (default UINT4 zp8 elided)" : ""); + modified = true; + } +} + +// --------------------------------------------------------------------------- +// Pattern 2 rewriting: direct DQ(axis=0) + MatMul/Gemm -> MatMulNBits +// --------------------------------------------------------------------------- + +void ApplyDirectDQFusions( + Graph& graph, + const std::vector& matches, + int64_t accuracy_level, + bool& modified, + const logging::Logger& logger) { + for (const auto& match : matches) { + const Node* mm_node = graph.GetNode(match.matmul_idx); + const Node* dq_node = graph.GetNode(match.dq_idx); + if (!mm_node || !dq_node) continue; + + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + + const auto& dq_attrs = dq_node->GetAttributes(); + const int64_t block_size = dq_attrs.at("block_size").i(); + + const ONNX_NAMESPACE::TensorProto* weight_tp = nullptr; + if (!graph.GetInitializedTensor(weight_arg->Name(), weight_tp) || !weight_tp) continue; + const ONNX_NAMESPACE::TensorProto* scale_tp = nullptr; + if (!graph.GetInitializedTensor(scale_arg->Name(), scale_tp) || !scale_tp) continue; + const ONNX_NAMESPACE::TensorProto* zp_tp = nullptr; + if (has_zp) { + if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; + } + + if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || + weight_tp->dims_size() != 2) continue; + + const int64_t K = weight_tp->dims(0); + const int64_t N = weight_tp->dims(1); + if (K <= 0 || N <= 0 || block_size <= 0 || K % block_size != 0) continue; + const int64_t k_blocks = K / block_size; + const int64_t blob_bytes = block_size / 2; + if (!HasRank2Shape(*scale_tp, k_blocks, N)) continue; + if (zp_tp && !HasRank2Shape(*zp_tp, k_blocks, N)) continue; + + Initializer weight_src(graph, *weight_tp, graph.ModelPath()); + const 
size_t required_weight_bytes = SafeInt(N) * k_blocks * blob_bytes; + if (weight_src.DataAsByteSpan().size() < required_weight_bytes) continue; + Initializer scale_src(graph, *scale_tp, graph.ModelPath()); + if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum( + ONNX_NAMESPACE::TensorProto_DataType_UINT8) + ->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum( + scale_src.data_type()) + ->GetElementType(); + auto cpu_allocator = CPUAllocator::DefaultInstance(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb"); + auto weight_dst = Tensor(uint8_type, TensorShape{N, k_blocks, blob_bytes}, cpu_allocator); + TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size, + weight_dst.MutableData()); + + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb"); + const int64_t scale_count = SafeInt(N) * k_blocks; + if (scale_src.size() != static_cast(scale_count)) continue; + auto scale_dst = Tensor(scale_type, TensorShape{scale_count}, cpu_allocator); + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + const float* src = scale_src.data(); + float* dst = scale_dst.MutableData(); + for (int64_t n = 0; n < N; ++n) + for (int64_t kb = 0; kb < k_blocks; ++kb) + dst[n * k_blocks + kb] = src[kb * N + n]; + } else { + const MLFloat16* src = scale_src.data(); + MLFloat16* dst = scale_dst.MutableData(); + for (int64_t n = 0; n < N; ++n) + for (int64_t kb = 0; kb < k_blocks; ++kb) + dst[n * k_blocks + kb] = src[kb * N + n]; + } + + std::string zp_dst_name; + std::optional zp_dst; + const int64_t zp_bytes_total = SafeInt(N) * ((k_blocks + 1) / 2); + + bool elide_zp = false; + + if (zp_tp) { + Initializer zp_src(graph, *zp_tp, graph.ModelPath()); + if (zp_src.data_type() != 
ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (zp_src.size() != static_cast(k_blocks * N)) continue; + + if (IsUniformPackedUint4Value(zp_src, 8)) { + elide_zp = true; + } else { + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); + TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N, + zp_dst->MutableData()); + } + } else { + // DQ default ZP for UINT4 is 0, MatMulNBits default is 8. Emit explicit zeros. + zp_dst_name = graph.GenerateNodeArgName("direct_DQ_zp_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); + } + + auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); + std::optional zp_mnb_tp; + if (zp_dst && !elide_zp) { + zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + } + + NodeAttributes mnb_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); + + std::vector mnb_inputs; + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); + if (zp_mnb_tp) { + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); + } + + bool fused_with_bias = false; + if (mm_node->OpType() == "Gemm" && + 
mm_node->InputDefs().size() > 2 && + mm_node->InputDefs()[2] && + mm_node->InputDefs()[2]->Exists()) { + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (mnb_inputs.size() < 5) { + mnb_inputs.push_back(&empty_arg); + } + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[2])); + fused_with_bias = true; + } + + std::vector mnb_outputs; + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + + auto& mnb_node = graph.AddNode( + graph.GenerateNodeName("DirectDQFusedMatMulNBits"), + "MatMulNBits", + "Fused from direct DQ(axis=0)+MatMul", + mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); + mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); + graph.RemoveNode(match.matmul_idx); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx)); + graph.RemoveNode(match.dq_idx); + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused direct DQ(axis=0)+MatMul/Gemm -> MatMulNBits" + << " (K=" << K << ", N=" << N << ", block_size=" << block_size << ")" + << (fused_with_bias ? " (bias preserved)" : "") + << (elide_zp ? 
" (default UINT4 zp8 elided)" : ""); + modified = true; + } +} + +} // namespace + +// --------------------------------------------------------------------------- +// DQMatMulNBitsFusion public interface +// --------------------------------------------------------------------------- + +DQMatMulNBitsFusion::DQMatMulNBitsFusion( + int64_t accuracy_level, + const InlinedHashSet& compatible_eps) + : GraphTransformer("DQMatMulNBitsFusion", compatible_eps), + accuracy_level_(accuracy_level) { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, + "MatMulNBits accuracy level must be between 0 and 4"); +} + +Status DQMatMulNBitsFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node = graph.GetNode(node_index); + if (!node) continue; + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + } + + auto matches = CollectReshapeTransposeMatches(graph, node_topology_list, logger); + + std::unordered_set matched_matmul_indices; + for (const auto& m : matches) { + matched_matmul_indices.insert(m.matmul_idx); + } + + auto direct_matches = CollectDirectDQMatches(graph, node_topology_list, + matched_matmul_indices, logger); + + ApplyReshapeTransposeFusions(graph, matches, accuracy_level_, modified, logger); + ApplyDirectDQFusions(graph, direct_matches, accuracy_level_, modified, logger); + + return Status::OK(); +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h new file mode 100644 index 0000000000000..97c0debd760c0 --- /dev/null +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +// Fuses DequantizeLinear chains back into a single MatMulNBits contrib op. +// +// Supported patterns: +// Pattern 1: DQ(3D, UINT4, axis=2) -> Reshape(2D) -> Transpose([1,0]) +// -> [optional Cast] -> MatMul/Gemm => MatMulNBits +// Pattern 2: DQ(2D, UINT4, axis=0) -> MatMul/Gemm => MatMulNBits +// +// These patterns are produced when a quantized model goes through external +// toolchains that lower MatMulNBits to DQ + reshape/transpose + MatMul +// primitives, and then re-import the graph into ORT. +class DQMatMulNBitsFusion : public GraphTransformer { + public: + explicit DQMatMulNBitsFusion( + int64_t accuracy_level = 4, + const InlinedHashSet& compatible_eps = {}); + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const override; + + int64_t accuracy_level_; +}; + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index fdd4f5aa27862..4edabbe6058ab 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -18,6 +18,8 @@ #if !defined(ORT_MINIMAL_BUILD) +#include "core/optimizer/dq_matmulnbits_fusion.h" + #include "core/mlas/inc/mlas.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/bias_dropout_fusion.h" @@ -274,6 +276,26 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } +#if !defined(DISABLE_CONTRIB_OPS) + { + const bool enable_dq_matmulnbits_fusion = + session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "0") == "1"; + if (enable_dq_matmulnbits_fusion && !disable_quant_qdq) { + const int64_t qdq_matmulnbits_accuracy_level = + 
ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); + transformers.emplace_back(std::make_unique( + qdq_matmulnbits_accuracy_level)); + } + } +#else + ORT_ENFORCE(session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "0") != "1", + "DQ->MatMulNBits fusion requires contrib ops but DISABLE_CONTRIB_OPS is defined"); +#endif + // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c00d63d0be8a2..a3c436b486314 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -38,6 +38,7 @@ #include "core/framework/plugin_ep_stream.h" #include "core/framework/transform_layout_functions.h" #include "core/framework/utils.h" +#include "core/graph/constants.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/graph/model_editor_api_types.h" @@ -2278,6 +2279,16 @@ common::Status InferenceSession::Initialize() { return Status::OK(); }; + // Enable DQ->MatMulNBits fusion if NvTensorRTRTX EP is registered. 
+ if (execution_providers_.Get(onnxruntime::kNvTensorRTRTXExecutionProvider) != nullptr) { + if (session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "") == "") { + ORT_RETURN_IF_ERROR_SESSIONID_( + session_options_.config_options.AddConfigEntry( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1")); + } + } + // add predefined transformers ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_, session_options_.graph_optimization_level, diff --git a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc new file mode 100644 index 0000000000000..8aa4c88052742 --- /dev/null +++ b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc @@ -0,0 +1,595 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for the DQMatMulNBitsFusion graph transformer. +// Tests Pattern 1: DQ(3D,axis=2)->Reshape->Transpose([1,0])->[Cast]->MatMul/Gemm -> MatMulNBits +// Tests Pattern 2: DQ(2D,axis=0)->MatMul/Gemm -> MatMulNBits + +#include "core/common/span_utils.h" +#include "core/framework/int4.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/dq_matmulnbits_fusion.h" + +#include "test/test_environment.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" +#include "test/util/include/asserts.h" + +#include "gtest/gtest.h" + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace onnxruntime { +namespace test { + +static std::vector MakePackedUint4(const std::vector& values) { + const size_t num_pairs = UInt4x2::CalcNumInt4Pairs(values.size()); + std::vector packed(num_pairs); + for (size_t i = 0; i < values.size(); i += 2) { + uint8_t lo = values[i] & 0x0F; + uint8_t hi = (i + 1 < values.size()) ? 
(values[i + 1] & 0x0F) : 0; + packed[i / 2] = UInt4x2(lo, hi); + } + return packed; +} + +static void BuildPattern1Graph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp, + bool with_cast, + bool use_gemm, + const std::vector* weight_values = nullptr, + const std::vector* scale_values = nullptr, + const std::vector* zp_values = nullptr) { + const int64_t num_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + const int64_t weight_elems = N * num_blocks * block_size; + std::vector w_vals; + if (weight_values) { + w_vals = *weight_values; + } else { + w_vals.resize(static_cast(weight_elems)); + for (size_t i = 0; i < w_vals.size(); ++i) { + w_vals[i] = static_cast(i % 16); + } + } + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer( + {N, num_blocks, block_size}, w_packed); + + std::vector s_vals; + if (scale_values) { + s_vals = *scale_values; + } else { + s_vals.resize(static_cast(N * num_blocks)); + for (size_t i = 0; i < s_vals.size(); ++i) { + s_vals[i] = 0.1f + 0.01f * static_cast(i % 10); + } + } + auto* scale_arg = builder.MakeInitializer({N, num_blocks, 1}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(2)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + + auto* dq_output = builder.MakeIntermediate(); + if (with_zp) { + std::vector z_vals; + if (zp_values) { + z_vals = *zp_values; + } else { + z_vals.resize(static_cast(N * num_blocks)); + for (size_t i = 0; i < z_vals.size(); ++i) { + z_vals[i] = 8; + } + } + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({N, num_blocks, 1}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, 
{dq_output}, "", &dq_attrs); + } + + auto* reshape_shape = builder.MakeInitializer({2}, {N, K}); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output}); + + NodeAttributes tp_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector{1, 0}), tp_attrs); + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs); + + NodeArg* matmul_b = transpose_output; + + if (with_cast) { + auto* cast_output = builder.MakeIntermediate(); + NodeAttributes cast_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast(1)), cast_attrs); + builder.AddNode("Cast", {transpose_output}, {cast_output}, "", &cast_attrs); + matmul_b = cast_output; + } + + if (use_gemm) { + builder.AddNode("Gemm", {input_a, matmul_b}, {output}); + } else { + builder.AddNode("MatMul", {input_a, matmul_b}, {output}); + } +} + +static void BuildPattern1GemmBiasGraph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp) { + const int64_t num_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + const int64_t weight_elems = N * num_blocks * block_size; + std::vector w_vals(static_cast(weight_elems)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({N, num_blocks, block_size}, w_packed); + + std::vector s_vals(static_cast(N * num_blocks)); + for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f; + auto* scale_arg = builder.MakeInitializer({N, num_blocks, 1}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(2)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = 
builder.MakeIntermediate(); + + if (with_zp) { + std::vector z_vals(static_cast(N * num_blocks), 8); + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({N, num_blocks, 1}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + auto* reshape_shape = builder.MakeInitializer({2}, {N, K}); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output}); + + NodeAttributes tp_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector{1, 0}), tp_attrs); + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs); + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + builder.AddNode("Gemm", {input_a, transpose_output, bias_arg}, {output}); +} + +static void BuildPattern2Graph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp, + bool use_gemm) { + const int64_t k_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + std::vector w_vals(static_cast(K * N)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({K, N}, w_packed); + + std::vector s_vals(static_cast(k_blocks * N)); + for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f + 0.01f * static_cast(i % 10); + auto* scale_arg = builder.MakeInitializer({k_blocks, N}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = 
builder.MakeIntermediate(); + + if (with_zp) { + std::vector z_vals(static_cast(k_blocks * N), 8); + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({k_blocks, N}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + if (use_gemm) { + builder.AddNode("Gemm", {input_a, dq_output}, {output}); + } else { + builder.AddNode("MatMul", {input_a, dq_output}, {output}); + } +} + +static void BuildPattern2GemmBiasGraph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp) { + const int64_t k_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + std::vector w_vals(static_cast(K * N)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({K, N}, w_packed); + + std::vector s_vals(static_cast(k_blocks * N)); + for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f; + auto* scale_arg = builder.MakeInitializer({k_blocks, N}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + + if (with_zp) { + std::vector z_vals(static_cast(k_blocks * N), 8); + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({k_blocks, N}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + 
builder.AddNode("Gemm", {input_a, dq_output, bias_arg}, {output}); +} + +static void BuildPattern1WrongAxis(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size) { + const int64_t num_blocks = K / block_size; + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + std::vector w_vals(static_cast(N * num_blocks * block_size)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({N, num_blocks, block_size}, w_packed); + + std::vector s_vals(static_cast(N * num_blocks), 0.1f); + auto* scale_arg = builder.MakeInitializer({N, num_blocks, 1}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + + auto* reshape_shape = builder.MakeInitializer({2}, {N, K}); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output}); + + NodeAttributes tp_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector{1, 0}), tp_attrs); + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs); + + builder.AddNode("MatMul", {input_a, transpose_output}, {output}); +} + +static void BuildPattern2NonConstWeight(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size) { + const int64_t k_blocks = K / block_size; + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInput({K, N}, + UInt4x2(UInt4x2::min_val, 0), + UInt4x2(UInt4x2::max_val, 0)); + + std::vector 
s_vals(static_cast(k_blocks * N), 0.1f); + auto* scale_arg = builder.MakeInitializer({k_blocks, N}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + + builder.AddNode("MatMul", {input_a, dq_output}, {output}); +} + +static std::map CountOpsInGraphByDomain(const Graph& graph) { + std::map op_counts; + for (const auto& node : graph.Nodes()) { + std::string key = node.OpType(); + if (!node.Domain().empty() && node.Domain() != kOnnxDomain) { + key = node.Domain() + "." + key; + } + op_counts[key]++; + } + return op_counts; +} + +class DQMatMulNBitsFusionTest : public GraphTransformationTests {}; + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_NoZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, false, false, false); + }; + + auto pre_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["DequantizeLinear"], 1); + EXPECT_EQ(ops["Reshape"], 1); + EXPECT_EQ(ops["Transpose"], 1); + EXPECT_EQ(ops["MatMul"], 1); + return Status::OK(); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("DequantizeLinear"), 0); + EXPECT_EQ(ops.count("Reshape"), 0); + EXPECT_EQ(ops.count("Transpose"), 0); + EXPECT_EQ(ops.count("MatMul"), 0); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + const auto& attrs = node.GetAttributes(); + EXPECT_EQ(attrs.at("K").i(), K); + EXPECT_EQ(attrs.at("N").i(), N); + EXPECT_EQ(attrs.at("bits").i(), 4); + EXPECT_EQ(attrs.at("block_size").i(), 
block_size); + EXPECT_EQ(node.InputDefs().size(), static_cast(4)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_check, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithDefaultZP8) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, true, false, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("DequantizeLinear"), 0); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.InputDefs().size(), static_cast(3)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithNonDefaultZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector zp_vals(static_cast(N * num_blocks), 3); + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, true, false, false, + nullptr, nullptr, &zp_vals); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.InputDefs().size(), static_cast(4)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + 
+TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithCast) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, false, true, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Cast"), 0); + EXPECT_EQ(ops.count("MatMul"), 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_Gemm_WithBias) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1GemmBiasGraph(builder, M, N, K, block_size, true); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Gemm"), 0); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_GE(node.InputDefs().size(), static_cast(6)); + EXPECT_TRUE(node.InputDefs()[5] != nullptr); + EXPECT_TRUE(node.InputDefs()[5]->Exists()); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_Gemm_NoZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, false, false, true); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Gemm"), 0); + return 
Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern2_MatMul_NoZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2Graph(builder, M, N, K, block_size, false, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("DequantizeLinear"), 0); + EXPECT_EQ(ops.count("MatMul"), 0); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.GetAttributes().at("K").i(), K); + EXPECT_EQ(node.GetAttributes().at("N").i(), N); + EXPECT_EQ(node.InputDefs().size(), static_cast(4)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern2_MatMul_WithDefaultZP8) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2Graph(builder, M, N, K, block_size, true, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.InputDefs().size(), static_cast(3)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern2_Gemm_WithBias) { + constexpr int64_t M = 4, N = 8, K = 32, 
block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2GemmBiasGraph(builder, M, N, K, block_size, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Gemm"), 0); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_GE(node.InputDefs().size(), static_cast(6)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Negative_Pattern1_WrongAxis) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1WrongAxis(builder, M, N, K, block_size); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("com.microsoft.MatMulNBits"), 0); + EXPECT_EQ(ops["MatMul"], 1); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Negative_Pattern2_NonConstWeight) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2NonConstWeight(builder, M, N, K, block_size); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("com.microsoft.MatMulNBits"), 0); + EXPECT_EQ(ops["DequantizeLinear"], 1); + EXPECT_EQ(ops["MatMul"], 1); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); 
+} + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index dc687714f07cd..f7bfa3055f96d 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -9,6 +9,7 @@ #include "gtest/gtest.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_session_options_config_keys.h" using namespace ONNX_NAMESPACE; @@ -69,5 +70,31 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } +TEST(GraphTransformerUtilsTests, TestDQMatMulNBitsFusionConfigWithContribGating) { + SessionOptions session_options; + const auto status = session_options.config_options.AddConfigEntry( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1"); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + +#if defined(DISABLE_CONTRIB_OPS) + EXPECT_ANY_THROW({ + std::ignore = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, session_options, cpu_ep, logger); + }); +#else + auto transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, session_options, cpu_ep, logger); + + const bool has_dq_matmulnbits_fusion = + std::any_of(transformers.begin(), transformers.end(), [](const auto& transformer) { + return transformer && transformer->Name() == "DQMatMulNBitsFusion"; + }); + + EXPECT_TRUE(has_dq_matmulnbits_fusion); +#endif +} } // namespace test } // namespace onnxruntime