From 5f7df044813ada934b7c38a8ce1c6778c882176f Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 19:11:31 +0000 Subject: [PATCH 1/7] dq_mnb_fusion change --- .../core/optimizer/dq_matmulnbits_fusion.cc | 516 +++++++++++++----- .../core/optimizer/dq_matmulnbits_fusion.h | 10 +- 2 files changed, 400 insertions(+), 126 deletions(-) diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc index f9ae13808cf2c..13339d809374d 100644 --- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc @@ -27,8 +27,61 @@ namespace { // Utility helpers // --------------------------------------------------------------------------- -bool IsUniformPackedUint4Value(const Initializer& init, uint8_t expected_nibble) { - if (init.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) { +struct QuantTypeInfo { + int64_t bits; + bool is_signed; +}; + +// Map ONNX data types to quantization bit-width info. +std::optional GetQuantTypeInfo(int32_t data_type) { + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: + return QuantTypeInfo{2, false}; + case ONNX_NAMESPACE::TensorProto_DataType_INT2: + return QuantTypeInfo{2, true}; + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + return QuantTypeInfo{4, false}; + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + return QuantTypeInfo{4, true}; + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + return QuantTypeInfo{8, false}; + default: + return std::nullopt; + } +} + +// Extract a single N-bit element from packed data. +// For sub-byte types, elements are packed with even indices in the low bits. +uint8_t GetPackedElement(const uint8_t* packed, size_t index, size_t num_elements, int64_t bits) { + ORT_ENFORCE(index < num_elements, "GetPackedElement: index ", index, + " out of bounds (num_elements=", num_elements, ")"); + if (bits == 8) { + return packed[index]; + } + const int elems_per_byte = 8 / static_cast(bits); + const size_t byte_index = index / elems_per_byte; + const int bit_offset = static_cast((index % elems_per_byte) * bits); + const uint8_t mask = static_cast((1 << bits) - 1); + return static_cast((packed[byte_index] >> bit_offset) & mask); +} + +// Set a single N-bit element in packed data. +void SetPackedElement(uint8_t* packed, size_t index, uint8_t value, int64_t bits) { + if (bits == 8) { + packed[index] = value; + return; + } + const int elems_per_byte = 8 / static_cast(bits); + const size_t byte_index = index / elems_per_byte; + const int bit_offset = static_cast((index % elems_per_byte) * bits); + const uint8_t mask = static_cast((1 << bits) - 1); + packed[byte_index] = static_cast( + (packed[byte_index] & ~(mask << bit_offset)) | ((value & mask) << bit_offset)); +} + +bool IsUniformPackedValue(const Initializer& init, uint8_t expected_value, int64_t bits) { + const auto qtype = GetQuantTypeInfo(init.data_type()); + if (!qtype || qtype->bits != bits) { return false; } @@ -38,11 +91,10 @@ bool IsUniformPackedUint4Value(const Initializer& init, uint8_t expected_nibble) } const auto packed = init.DataAsByteSpan(); - const uint8_t expected = static_cast(expected_nibble & 0x0F); + const uint8_t mask = static_cast((1 << bits) - 1); + const uint8_t expected = static_cast(expected_value & mask); for (size_t i = 0; i < values_count; ++i) { - const uint8_t byte = packed[i / 2]; - const uint8_t value = (i % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F); - if (value != expected) { + if (GetPackedElement(packed.data(), i, values_count, bits) != expected) { return false; } } @@ -54,16 +106,14 @@ bool HasRank2Shape(const ONNX_NAMESPACE::TensorProto& tp, int64_t dim0, int64_t return tp.dims_size() == 2 && tp.dims(0) == dim0 && tp.dims(1) == dim1; } -uint8_t GetPackedUint4Element(const uint8_t* packed, size_t index, size_t num_elements) { - ORT_ENFORCE(index < num_elements, "GetPackedUint4Element: index ", index, - " out of bounds (num_elements=", num_elements, ")"); - const uint8_t packed_byte = packed[index / 2]; - return (index % 2 == 0) ? static_cast(packed_byte & 0x0F) - : static_cast((packed_byte >> 4) & 0x0F); +// Compute the number of bytes needed to store 'count' N-bit elements. +int64_t PackedByteSize(int64_t count, int64_t bits) { + return (count * bits + 7) / 8; } -void PackUint4Rows(const Initializer& src, int64_t rows, int64_t cols, uint8_t* dst) { - const int64_t row_bytes = (cols + 1) / 2; +// Pack N-bit elements row-by-row from DQ layout to MatMulNBits layout. +void PackRows(const Initializer& src, int64_t rows, int64_t cols, int64_t bits, uint8_t* dst) { + const int64_t row_bytes = PackedByteSize(cols, bits); const size_t dst_bytes = SafeInt(rows) * row_bytes; const size_t total_elements = SafeInt(rows) * cols; memset(dst, 0, dst_bytes); @@ -72,28 +122,21 @@ void PackUint4Rows(const Initializer& src, int64_t rows, int64_t cols, uint8_t* for (int64_t r = 0; r < rows; ++r) { for (int64_t c = 0; c < cols; ++c) { const size_t src_index = SafeInt(r) * cols + c; - const uint8_t value = GetPackedUint4Element(src_packed.data(), src_index, total_elements); + const uint8_t value = GetPackedElement(src_packed.data(), src_index, total_elements, bits); - const size_t dst_index = SafeInt(r) * row_bytes + c / 2; - if ((c & 1) == 0) { - dst[dst_index] = value; - } else { - dst[dst_index] = static_cast(dst[dst_index] | (value << 4)); - } + const size_t dst_index = SafeInt(r) * row_bytes * (8 / bits) + c; + SetPackedElement(dst, dst_index, value, bits); } } } -// Transpose and pack UINT4 weights from DQ axis=0 layout [K, N] to MatMulNBits layout [N, k_blocks, blob_size]. -// Source: row-major UINT4 with quantization along K (axis=0), shape [K, N]. -// The nibble ordering follows ONNX UINT4 convention: even indices in the low nibble, -// odd indices in the high nibble of each byte. -// Dest: UINT8 [N, k_blocks, block_size/2] where each byte packs two 4-bit weights. +// Transpose and pack N-bit weights from DQ axis=0 layout [K, N] to MatMulNBits layout +// [N, k_blocks, blob_size]. blob_size = block_size * bits / 8. void TransposePackWeightsAxis0( - const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size, + const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size, int64_t bits, uint8_t* dst) { const int64_t k_blocks = (K + block_size - 1) / block_size; - const int64_t blob_size = block_size / 2; + const int64_t blob_size = block_size * bits / 8; const size_t dst_bytes = SafeInt(N) * k_blocks * blob_size; const size_t total_elements = SafeInt(K) * N; memset(dst, 0, dst_bytes); @@ -101,26 +144,23 @@ void TransposePackWeightsAxis0( for (int64_t n = 0; n < N; ++n) { for (int64_t k = 0; k < K; ++k) { const size_t src_index = SafeInt(k) * N + n; - const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements); + const uint8_t val = GetPackedElement(src_packed, src_index, total_elements, bits); const int64_t kb = k / block_size; const int64_t off = k % block_size; - const size_t dst_byte = SafeInt(n) * k_blocks * blob_size + kb * blob_size + off / 2; - if (off % 2 == 0) { - dst[dst_byte] = static_cast((dst[dst_byte] & 0xF0) | val); - } else { - dst[dst_byte] = static_cast((dst[dst_byte] & 0x0F) | (val << 4)); - } + // Destination: element index within the block's blob + const size_t dst_elem = SafeInt(n) * k_blocks * block_size + kb * block_size + off; + SetPackedElement(dst, dst_elem, val, bits); } } } -// Transpose and pack UINT4 zero points from DQ axis=0 layout [k_blocks, N] to -// MatMulNBits layout UINT8 [N, ceil(k_blocks/2)]. +// Transpose and pack N-bit zero points from DQ axis=0 layout [k_blocks, N] to +// MatMulNBits layout UINT8 [N, packed_zp_bytes_per_n]. void TransposePackZPAxis0( - const uint8_t* src_packed, int64_t k_blocks, int64_t N, + const uint8_t* src_packed, int64_t k_blocks, int64_t N, int64_t bits, uint8_t* dst) { - const int64_t zp_bytes_per_n = (k_blocks + 1) / 2; + const int64_t zp_bytes_per_n = PackedByteSize(k_blocks, bits); const size_t dst_bytes = SafeInt(N) * zp_bytes_per_n; const size_t total_elements = SafeInt(k_blocks) * N; memset(dst, 0, dst_bytes); @@ -128,33 +168,50 @@ void TransposePackZPAxis0( for (int64_t n = 0; n < N; ++n) { for (int64_t kb = 0; kb < k_blocks; ++kb) { const size_t src_index = SafeInt(kb) * N + n; - const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements); + const uint8_t val = GetPackedElement(src_packed, src_index, total_elements, bits); - const size_t dst_byte = SafeInt(n) * zp_bytes_per_n + kb / 2; - if (kb % 2 == 0) { - dst[dst_byte] = static_cast((dst[dst_byte] & 0xF0) | val); - } else { - dst[dst_byte] = static_cast((dst[dst_byte] & 0x0F) | (val << 4)); - } + const size_t dst_elem = SafeInt(n) * k_blocks + kb; + SetPackedElement(dst, dst_elem, val, bits); } } } +// Returns the Cast node's target element type (the "to" attribute), or nullopt if invalid. +std::optional GetCastToType(const Node& cast_node) { + const auto* to_attr = graph_utils::GetNodeAttribute(cast_node, "to"); + if (!to_attr) return std::nullopt; + return static_cast(to_attr->i()); +} + +bool IsFloatOrFloat16(int32_t dt) { + return dt == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + dt == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +} + // --------------------------------------------------------------------------- // Match structs // --------------------------------------------------------------------------- struct FusionMatch { NodeIndex matmul_idx; - std::optional cast_idx; + std::optional weight_cast_idx; // Cast on weight path (between Transpose and MatMul) NodeIndex transpose_idx; NodeIndex reshape_idx; NodeIndex dq_idx; + int64_t bits; + std::optional input_a_cast_idx; // Cast on input A + std::optional output_cast_idx; // Cast on MatMul output + int32_t effective_dt_a; // T1 for MatMulNBits (scale type) }; struct DirectDQMatch { NodeIndex matmul_idx; NodeIndex dq_idx; + int64_t bits; + std::optional weight_cast_idx; // Cast on weight path (between DQ and MatMul) + std::optional input_a_cast_idx; // Cast on input A + std::optional output_cast_idx; // Cast on MatMul output + int32_t effective_dt_a; // T1 for MatMulNBits (scale type) }; // --------------------------------------------------------------------------- @@ -188,6 +245,7 @@ bool ValidateGemmForFusion(const Node& gemm_node, int64_t N) { // --------------------------------------------------------------------------- // Pattern 1 matching: DQ -> Reshape -> Transpose -> [Cast] -> MatMul/Gemm +// With optional Cast on input A and/or MatMul output for FP16 models. // --------------------------------------------------------------------------- std::vector CollectReshapeTransposeMatches( @@ -205,12 +263,13 @@ std::vector CollectReshapeTransposeMatches( const auto& mm_inputs = node->InputDefs(); if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue; - const Node* cast_node = nullptr; + // Trace weight path: MatMul input B <- [Cast] <- Transpose <- Reshape <- DQ + const Node* weight_cast_node = nullptr; const Node* transpose_node = graph.GetProducerNode(mm_inputs[1]->Name()); if (transpose_node && transpose_node->OpType() == "Cast") { - cast_node = transpose_node; - if (cast_node->GetOutputEdgesCount() != 1) continue; - const auto& cast_inputs = cast_node->InputDefs(); + weight_cast_node = transpose_node; + if (weight_cast_node->GetOutputEdgesCount() != 1) continue; + const auto& cast_inputs = weight_cast_node->InputDefs(); if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue; transpose_node = graph.GetProducerNode(cast_inputs[0]->Name()); } @@ -243,11 +302,14 @@ std::vector CollectReshapeTransposeMatches( if (block_size < 16 || ((block_size - 1) & block_size)) continue; } + // Validate weight type: must be a supported quantized type const auto* weight_arg = dq_node->InputDefs()[0]; if (!weight_arg || !weight_arg->Exists()) continue; const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); if (!weight_const_tp) continue; - if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + const auto weight_qtype = GetQuantTypeInfo(weight_const_tp->data_type()); + if (!weight_qtype) continue; + const int64_t bits = weight_qtype->bits; if (weight_const_tp->dims_size() != 3) continue; const int64_t N = weight_const_tp->dims(0); const int64_t blocks = weight_const_tp->dims(1); @@ -256,18 +318,59 @@ std::vector CollectReshapeTransposeMatches( if (bs_dim != block_size) continue; const int64_t K = SafeInt(blocks) * bs_dim; + // Scale type determines the effective T1 for MatMulNBits const auto* scale_arg = dq_node->InputDefs()[1]; if (!scale_arg || !scale_arg->Exists()) continue; const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); if (!scale_const_tp) continue; int32_t dt_scale = scale_const_tp->data_type(); - if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + if (!IsFloatOrFloat16(dt_scale)) continue; + // Check input A type, looking through optional Cast const auto* a_arg = mm_inputs[0]; if (!a_arg || !a_arg->TypeAsProto()) continue; int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); - if (dt_a != dt_scale) continue; + + const Node* input_a_cast_node = nullptr; + int32_t effective_dt_a = dt_a; + + if (dt_a != dt_scale) { + // Check if input A is produced by a Cast from dt_scale + const Node* a_producer = graph.GetProducerNode(a_arg->Name()); + if (a_producer && a_producer->OpType() == "Cast") { + const auto cast_to = GetCastToType(*a_producer); + if (cast_to && *cast_to == dt_a) { + const auto* cast_in = a_producer->InputDefs().empty() ? nullptr : a_producer->InputDefs()[0]; + if (cast_in && cast_in->TypeAsProto()) { + int32_t dt_cast_in = cast_in->TypeAsProto()->tensor_type().elem_type(); + if (dt_cast_in == dt_scale && a_producer->GetOutputEdgesCount() == 1) { + input_a_cast_node = a_producer; + effective_dt_a = dt_scale; + } + } + } + } + if (effective_dt_a != dt_scale) continue; + } + + // Validate weight-path Cast: must cast to dt_a (the MatMul compute type) + if (weight_cast_node) { + const auto cast_to = GetCastToType(*weight_cast_node); + if (!cast_to || *cast_to != dt_a) continue; + } + + // Check for Cast on MatMul output + const Node* output_cast_node = nullptr; + if (node->GetOutputEdgesCount() == 1) { + const auto edge = node->OutputEdgesBegin(); + const Node& consumer = edge->GetNode(); + if (consumer.OpType() == "Cast") { + const auto cast_to = GetCastToType(consumer); + if (cast_to && *cast_to == dt_scale && consumer.GetOutputEdgesCount() >= 1) { + output_cast_node = &consumer; + } + } + } const auto* reshape_shape_arg = reshape_node->InputDefs().size() > 1 ? reshape_node->InputDefs()[1] : nullptr; @@ -337,37 +440,35 @@ std::vector CollectReshapeTransposeMatches( if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; - if (cast_node) { - const auto* cast_in = cast_node->InputDefs().empty() ? nullptr : cast_node->InputDefs()[0]; - const auto* cast_out = cast_node->OutputDefs().empty() ? nullptr : cast_node->OutputDefs()[0]; - if (!cast_in || !cast_out || !cast_in->TypeAsProto() || !cast_out->TypeAsProto()) continue; - if (cast_in->TypeAsProto()->tensor_type().elem_type() != - cast_out->TypeAsProto()->tensor_type().elem_type()) { - continue; - } - } - + // Validate zero-point type matches weight type const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; bool has_zp = zp_arg && zp_arg->Exists(); if (has_zp) { const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); - if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (!zp_const_tp) continue; + const auto zp_qtype = GetQuantTypeInfo(zp_const_tp->data_type()); + if (!zp_qtype || zp_qtype->bits != bits) continue; } LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched pattern at MatMul node '" - << node->Name() << "'"; + << node->Name() << "' (bits=" << bits << ")"; matches.push_back({node->Index(), - cast_node ? std::optional(cast_node->Index()) : std::nullopt, + weight_cast_node ? std::optional(weight_cast_node->Index()) : std::nullopt, transpose_node->Index(), - reshape_node->Index(), dq_node->Index()}); + reshape_node->Index(), dq_node->Index(), + bits, + input_a_cast_node ? std::optional(input_a_cast_node->Index()) : std::nullopt, + output_cast_node ? std::optional(output_cast_node->Index()) : std::nullopt, + dt_scale}); } return matches; } // --------------------------------------------------------------------------- -// Pattern 2 matching: direct DQ(axis=0, 2D UINT4) -> MatMul/Gemm +// Pattern 2 matching: direct DQ(axis=0, 2D) -> [Cast] -> MatMul/Gemm +// With optional Cast on input A and/or MatMul output for FP16 models. // --------------------------------------------------------------------------- std::vector CollectDirectDQMatches( @@ -387,7 +488,17 @@ std::vector CollectDirectDQMatches( const auto& mm_inputs = node->InputDefs(); if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue; + // Trace weight path: MatMul input B <- [Cast] <- DQ + const Node* weight_cast_node = nullptr; const Node* dq_node = graph.GetProducerNode(mm_inputs[1]->Name()); + if (dq_node && dq_node->OpType() == "Cast") { + weight_cast_node = dq_node; + if (weight_cast_node->GetOutputEdgesCount() != 1) continue; + const auto& cast_inputs = weight_cast_node->InputDefs(); + if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue; + dq_node = graph.GetProducerNode(cast_inputs[0]->Name()); + } + if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue; if (dq_node->GetOutputEdgesCount() != 1) continue; @@ -408,7 +519,9 @@ std::vector CollectDirectDQMatches( if (!weight_arg || !weight_arg->Exists()) continue; const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); if (!weight_const_tp) continue; - if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + const auto weight_qtype = GetQuantTypeInfo(weight_const_tp->data_type()); + if (!weight_qtype) continue; + const int64_t bits = weight_qtype->bits; if (weight_const_tp->dims_size() != 2) continue; const int64_t K = weight_const_tp->dims(0); const int64_t N = weight_const_tp->dims(1); @@ -420,28 +533,74 @@ std::vector CollectDirectDQMatches( const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); if (!scale_const_tp) continue; int32_t dt_scale = scale_const_tp->data_type(); - if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + if (!IsFloatOrFloat16(dt_scale)) continue; if (!HasRank2Shape(*scale_const_tp, k_blocks, N)) continue; + // Check input A type, looking through optional Cast const auto* a_arg = mm_inputs[0]; if (!a_arg || !a_arg->TypeAsProto()) continue; int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); - if (dt_a != dt_scale) continue; + + const Node* input_a_cast_node = nullptr; + int32_t effective_dt_a = dt_a; + + if (dt_a != dt_scale) { + const Node* a_producer = graph.GetProducerNode(a_arg->Name()); + if (a_producer && a_producer->OpType() == "Cast") { + const auto cast_to = GetCastToType(*a_producer); + if (cast_to && *cast_to == dt_a) { + const auto* cast_in = a_producer->InputDefs().empty() ? nullptr : a_producer->InputDefs()[0]; + if (cast_in && cast_in->TypeAsProto()) { + int32_t dt_cast_in = cast_in->TypeAsProto()->tensor_type().elem_type(); + if (dt_cast_in == dt_scale && a_producer->GetOutputEdgesCount() == 1) { + input_a_cast_node = a_producer; + effective_dt_a = dt_scale; + } + } + } + } + if (effective_dt_a != dt_scale) continue; + } + + // Validate weight-path Cast + if (weight_cast_node) { + const auto cast_to = GetCastToType(*weight_cast_node); + if (!cast_to || *cast_to != dt_a) continue; + } + + // Check for Cast on MatMul output + const Node* output_cast_node = nullptr; + if (node->GetOutputEdgesCount() == 1) { + const auto edge = node->OutputEdgesBegin(); + const Node& consumer = edge->GetNode(); + if (consumer.OpType() == "Cast") { + const auto cast_to = GetCastToType(consumer); + if (cast_to && *cast_to == dt_scale && consumer.GetOutputEdgesCount() >= 1) { + output_cast_node = &consumer; + } + } + } const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; bool has_zp = zp_arg && zp_arg->Exists(); if (has_zp) { const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); - if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (!zp_const_tp) continue; + const auto zp_qtype = GetQuantTypeInfo(zp_const_tp->data_type()); + if (!zp_qtype || zp_qtype->bits != bits) continue; if (!HasRank2Shape(*zp_const_tp, k_blocks, N)) continue; } if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched direct DQ->MatMul pattern at node '" - << node->Name() << "' (K=" << K << ", N=" << N << ", block_size=" << block_size << ")"; - direct_matches.push_back({node->Index(), dq_node->Index()}); + << node->Name() << "' (K=" << K << ", N=" << N + << ", block_size=" << block_size << ", bits=" << bits << ")"; + direct_matches.push_back({node->Index(), dq_node->Index(), bits, + weight_cast_node ? std::optional(weight_cast_node->Index()) : std::nullopt, + input_a_cast_node ? std::optional(input_a_cast_node->Index()) : std::nullopt, + output_cast_node ? std::optional(output_cast_node->Index()) : std::nullopt, + dt_scale}); } return direct_matches; @@ -459,15 +618,22 @@ void ApplyReshapeTransposeFusions( const logging::Logger& logger) { for (const auto& match : matches) { const Node* mm_node = graph.GetNode(match.matmul_idx); - const Node* cast_node = match.cast_idx ? graph.GetNode(*match.cast_idx) : nullptr; + const Node* weight_cast_node = match.weight_cast_idx ? graph.GetNode(*match.weight_cast_idx) : nullptr; const Node* tp_node = graph.GetNode(match.transpose_idx); const Node* dq_node = graph.GetNode(match.dq_idx); const Node* reshape_node = graph.GetNode(match.reshape_idx); + const Node* input_a_cast_node = match.input_a_cast_idx ? graph.GetNode(*match.input_a_cast_idx) : nullptr; + const Node* output_cast_node = match.output_cast_idx ? graph.GetNode(*match.output_cast_idx) : nullptr; if (!mm_node || !tp_node || !dq_node || !reshape_node || - (match.cast_idx && !cast_node)) { + (match.weight_cast_idx && !weight_cast_node) || + (match.input_a_cast_idx && !input_a_cast_node) || + (match.output_cast_idx && !output_cast_node)) { continue; } + const int64_t bits = match.bits; + const int32_t effective_dt_a = match.effective_dt_a; + const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; @@ -485,8 +651,8 @@ void ApplyReshapeTransposeFusions( if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; } - if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || - weight_tp->dims_size() != 3) { + const auto weight_qtype = GetQuantTypeInfo(weight_tp->data_type()); + if (!weight_qtype || weight_qtype->bits != bits || weight_tp->dims_size() != 3) { continue; } @@ -495,12 +661,11 @@ void ApplyReshapeTransposeFusions( const int64_t bs_dim = weight_tp->dims(2); if (N <= 0 || quant_num <= 0 || bs_dim <= 0 || bs_dim != block_size) continue; const int64_t K = SafeInt(quant_num) * bs_dim; - const int64_t blob_bytes = (block_size + 1) / 2; + const int64_t blob_bytes = PackedByteSize(block_size, bits); Initializer weight_src(graph, *weight_tp, graph.ModelPath()); Initializer scale_src(graph, *scale_tp, graph.ModelPath()); - if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + if (!IsFloatOrFloat16(scale_src.data_type())) { continue; } @@ -523,10 +688,11 @@ void ApplyReshapeTransposeFusions( std::string zp_dst_name; std::optional zp_dst; - const int64_t zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); + const int64_t zp_packed_size = SafeInt(N) * PackedByteSize(quant_num, bits); - bool elide_default_uint4_zp8_input = false; + bool elide_default_zp = false; std::optional zp_src; + const uint8_t mnb_default_zp = static_cast(1 << (bits - 1)); // 2^(bits-1) const auto weight_bytes = weight_src.DataAsByteSpan(); if (weight_bytes.size() != static_cast(weight_dst.SizeInBytes())) continue; @@ -542,30 +708,29 @@ void ApplyReshapeTransposeFusions( if (zp_tp) { zp_src.emplace(graph, *zp_tp, graph.ModelPath()); - if (zp_src->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + const auto zp_qtype = GetQuantTypeInfo(zp_src->data_type()); + if (!zp_qtype || zp_qtype->bits != bits) continue; if (zp_src->size() != static_cast(N * quant_num)) continue; - const bool is_default_uint4_8 = - IsUniformPackedUint4Value(*zp_src, /*expected_nibble*/ 8); - if (is_default_uint4_8) { - elide_default_uint4_zp8_input = true; + if (IsUniformPackedValue(*zp_src, mnb_default_zp, bits)) { + elide_default_zp = true; } else { zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); - zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); - PackUint4Rows(*zp_src, N, quant_num, zp_dst->MutableData()); + zp_dst = Tensor(uint8_type, TensorShape{zp_packed_size}, cpu_allocator); + PackRows(*zp_src, N, quant_num, bits, zp_dst->MutableData()); } } else { - // DequantizeLinear default zero-point for uint4 is 0, while MatMulNBits - // default is 8. Emit explicit zeros to preserve semantics. + // DequantizeLinear default zero-point is 0, while MatMulNBits + // default is 2^(bits-1). Emit explicit zeros to preserve semantics. zp_dst_name = graph.GenerateNodeArgName("fused_DQ_zp_mnb"); - zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + zp_dst = Tensor(uint8_type, TensorShape{zp_packed_size}, cpu_allocator); memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); } auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); std::optional zp_mnb_tp; - if (zp_dst && !elide_default_uint4_zp8_input) { + if (zp_dst && !elide_default_zp) { zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } @@ -573,11 +738,16 @@ void ApplyReshapeTransposeFusions( utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); - utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); + // Determine input A: use pre-Cast value if input A Cast is being removed + NodeArg* mnb_input_a = input_a_cast_node + ? const_cast(input_a_cast_node->InputDefs()[0]) + : const_cast(mm_node->InputDefs()[0]); + std::vector mnb_inputs; - mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); + mnb_inputs.push_back(mnb_input_a); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { @@ -598,8 +768,20 @@ void ApplyReshapeTransposeFusions( fused_with_bias = true; } + // Determine output: if output Cast exists, take over its output; otherwise MatMul's output std::vector mnb_outputs; - mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + if (output_cast_node) { + mnb_outputs.push_back(const_cast(output_cast_node->OutputDefs()[0])); + } else if (effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + // MatMulNBits outputs T1 (=effective_dt_a) but consumers expect the original type. + // Create a new intermediate output for MatMulNBits and insert a Cast after it. + auto& mnb_out_arg = graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("mnb_out"), + nullptr); + mnb_outputs.push_back(&mnb_out_arg); + } else { + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + } auto& mnb_node = graph.AddNode( graph.GenerateNodeName("DQFusedMatMulNBits"), @@ -608,12 +790,37 @@ void ApplyReshapeTransposeFusions( mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); + // If we need a Cast after MatMulNBits (no output_cast_node to absorb the type difference) + if (!output_cast_node && + effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + int32_t original_dt = mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + NodeAttributes cast_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast(original_dt)), cast_attrs); + graph.AddNode( + graph.GenerateNodeName("DQFusedMatMulNBits_Cast"), + "Cast", "Cast MNB output to original type", + {mnb_outputs[0]}, + {const_cast(mm_node->OutputDefs()[0])}, + &cast_attrs); + } + + // Remove nodes in reverse dependency order + if (output_cast_node) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.output_cast_idx.value())); + graph.RemoveNode(match.output_cast_idx.value()); + } + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); graph.RemoveNode(match.matmul_idx); - if (match.cast_idx && graph.GetNode(*match.cast_idx)) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.cast_idx)); - graph.RemoveNode(*match.cast_idx); + if (input_a_cast_node) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.input_a_cast_idx.value())); + graph.RemoveNode(match.input_a_cast_idx.value()); + } + + if (match.weight_cast_idx && graph.GetNode(*match.weight_cast_idx)) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.weight_cast_idx)); + graph.RemoveNode(*match.weight_cast_idx); } graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.transpose_idx)); @@ -626,16 +833,19 @@ void ApplyReshapeTransposeFusions( graph.RemoveNode(match.dq_idx); LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused DQ+Reshape+Transpose" - << (match.cast_idx ? "+Cast" : "") + << (match.weight_cast_idx ? "+Cast" : "") << "+MatMul/Gemm -> MatMulNBits" + << " (bits=" << bits << ")" << (fused_with_bias ? " (bias preserved)" : "") - << (elide_default_uint4_zp8_input ? " (default UINT4 zp8 elided)" : ""); + << (elide_default_zp ? " (default zp elided)" : "") + << (input_a_cast_node ? " (input Cast removed)" : "") + << (output_cast_node ? " (output Cast removed)" : ""); modified = true; } } // --------------------------------------------------------------------------- -// Pattern 2 rewriting: direct DQ(axis=0) + MatMul/Gemm -> MatMulNBits +// Pattern 2 rewriting: direct DQ(axis=0) + [Cast] + MatMul/Gemm -> MatMulNBits // --------------------------------------------------------------------------- void ApplyDirectDQFusions( @@ -647,7 +857,18 @@ void ApplyDirectDQFusions( for (const auto& match : matches) { const Node* mm_node = graph.GetNode(match.matmul_idx); const Node* dq_node = graph.GetNode(match.dq_idx); - if (!mm_node || !dq_node) continue; + const Node* weight_cast_node = match.weight_cast_idx ? graph.GetNode(*match.weight_cast_idx) : nullptr; + const Node* input_a_cast_node = match.input_a_cast_idx ? graph.GetNode(*match.input_a_cast_idx) : nullptr; + const Node* output_cast_node = match.output_cast_idx ? graph.GetNode(*match.output_cast_idx) : nullptr; + if (!mm_node || !dq_node || + (match.weight_cast_idx && !weight_cast_node) || + (match.input_a_cast_idx && !input_a_cast_node) || + (match.output_cast_idx && !output_cast_node)) { + continue; + } + + const int64_t bits = match.bits; + const int32_t effective_dt_a = match.effective_dt_a; const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -666,14 +887,14 @@ void ApplyDirectDQFusions( if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; } - if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || - weight_tp->dims_size() != 2) continue; + const auto weight_qtype = GetQuantTypeInfo(weight_tp->data_type()); + if (!weight_qtype || weight_qtype->bits != bits || weight_tp->dims_size() != 2) continue; const int64_t K = weight_tp->dims(0); const int64_t N = weight_tp->dims(1); if (K <= 0 || N <= 0 || block_size <= 0 || K % block_size != 0) continue; const int64_t k_blocks = K / block_size; - const int64_t blob_bytes = block_size / 2; + const int64_t blob_bytes = block_size * bits / 8; if (!HasRank2Shape(*scale_tp, k_blocks, N)) continue; if (zp_tp && !HasRank2Shape(*zp_tp, k_blocks, N)) continue; @@ -681,8 +902,7 @@ void ApplyDirectDQFusions( const size_t required_weight_bytes = SafeInt(N) * k_blocks * blob_bytes; if (weight_src.DataAsByteSpan().size() < required_weight_bytes) continue; Initializer scale_src(graph, *scale_tp, graph.ModelPath()); - if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + if (!IsFloatOrFloat16(scale_src.data_type())) continue; auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum( ONNX_NAMESPACE::TensorProto_DataType_UINT8) @@ -694,7 +914,7 @@ void ApplyDirectDQFusions( auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb"); auto weight_dst = Tensor(uint8_type, TensorShape{N, k_blocks, blob_bytes}, cpu_allocator); - TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size, + TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size, bits, weight_dst.MutableData()); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb"); @@ -718,25 +938,27 @@ void ApplyDirectDQFusions( std::string zp_dst_name; std::optional zp_dst; - const int64_t zp_bytes_total = SafeInt(N) * ((k_blocks + 1) / 2); + const int64_t zp_bytes_total = SafeInt(N) * PackedByteSize(k_blocks, bits); bool elide_zp = false; + const uint8_t mnb_default_zp = static_cast(1 << (bits - 1)); if (zp_tp) { Initializer zp_src(graph, *zp_tp, graph.ModelPath()); - if (zp_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + const auto zp_qtype = GetQuantTypeInfo(zp_src.data_type()); + if (!zp_qtype || zp_qtype->bits != bits) continue; if (zp_src.size() != static_cast(k_blocks * N)) continue; - if (IsUniformPackedUint4Value(zp_src, 8)) { + if (IsUniformPackedValue(zp_src, mnb_default_zp, bits)) { elide_zp = true; } else { zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); - TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N, + TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N, bits, zp_dst->MutableData()); } } else { - // DQ default ZP for UINT4 is 0, MatMulNBits default is 8. Emit explicit zeros. + // DQ default ZP is 0, MatMulNBits default is 2^(bits-1). Emit explicit zeros. zp_dst_name = graph.GenerateNodeArgName("direct_DQ_zp_mnb"); zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); @@ -753,11 +975,16 @@ void ApplyDirectDQFusions( utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); - utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); + // Determine input A: use pre-Cast value if input A Cast is being removed + NodeArg* mnb_input_a = input_a_cast_node + ? const_cast(input_a_cast_node->InputDefs()[0]) + : const_cast(mm_node->InputDefs()[0]); + std::vector mnb_inputs; - mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); + mnb_inputs.push_back(mnb_input_a); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { @@ -777,8 +1004,18 @@ void ApplyDirectDQFusions( fused_with_bias = true; } + // Determine output: if output Cast exists, take over its output; otherwise MatMul's output std::vector mnb_outputs; - mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + if (output_cast_node) { + mnb_outputs.push_back(const_cast(output_cast_node->OutputDefs()[0])); + } else if (effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + auto& mnb_out_arg = graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("mnb_out"), + nullptr); + mnb_outputs.push_back(&mnb_out_arg); + } else { + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + } auto& mnb_node = graph.AddNode( graph.GenerateNodeName("DirectDQFusedMatMulNBits"), @@ -787,16 +1024,49 @@ void ApplyDirectDQFusions( mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); + // If we need a Cast after MatMulNBits + if (!output_cast_node && + effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + int32_t original_dt = mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + NodeAttributes cast_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast(original_dt)), cast_attrs); + graph.AddNode( + graph.GenerateNodeName("DirectDQFusedMatMulNBits_Cast"), + "Cast", "Cast MNB output to original type", + {mnb_outputs[0]}, + {const_cast(mm_node->OutputDefs()[0])}, + &cast_attrs); + } + + // Remove nodes in reverse dependency order + if (output_cast_node) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.output_cast_idx.value())); + graph.RemoveNode(match.output_cast_idx.value()); + } + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); graph.RemoveNode(match.matmul_idx); + if (input_a_cast_node) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.input_a_cast_idx.value())); + graph.RemoveNode(match.input_a_cast_idx.value()); + } + + if (weight_cast_node) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.weight_cast_idx.value())); + graph.RemoveNode(match.weight_cast_idx.value()); + } + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx)); graph.RemoveNode(match.dq_idx); LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused direct DQ(axis=0)+MatMul/Gemm -> MatMulNBits" - << " (K=" << K << ", N=" << N << ", block_size=" << block_size << ")" + << " (K=" << K << ", N=" << N << ", block_size=" << block_size + << ", bits=" << bits << ")" << (fused_with_bias ? " (bias preserved)" : "") - << (elide_zp ? " (default UINT4 zp8 elided)" : ""); + << (elide_zp ? " (default zp elided)" : "") + << (input_a_cast_node ? " (input Cast removed)" : "") + << (output_cast_node ? " (output Cast removed)" : ""); modified = true; } } diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h index 97c0debd760c0..b5d3e60f1be19 100644 --- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h @@ -13,10 +13,14 @@ namespace onnxruntime { // Fuses DequantizeLinear chains back into a single MatMulNBits contrib op. // -// Supported patterns: -// Pattern 1: DQ(3D, UINT4, axis=2) -> Reshape(2D) -> Transpose([1,0]) +// Supported patterns (weight types: UINT2, INT2, UINT4, INT4, UINT8): +// Pattern 1: DQ(3D, axis=2) -> Reshape(2D) -> Transpose([1,0]) // -> [optional Cast] -> MatMul/Gemm => MatMulNBits -// Pattern 2: DQ(2D, UINT4, axis=0) -> MatMul/Gemm => MatMulNBits +// Pattern 2: DQ(2D, axis=0) -> [optional Cast] -> MatMul/Gemm => MatMulNBits +// +// FP16 Cast handling: Cast nodes on input A (FP16→FP32), the weight path +// (FP16→FP32), and output (FP32→FP16) are absorbed into the fusion so that +// MatMulNBits operates directly on FP16 inputs/outputs. // // These patterns are produced when a quantized model goes through external // toolchains that lower MatMulNBits to DQ + reshape/transpose + MatMul From 7618741511afd8d3358ad6a24c8cdd554ba484f3 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 19:11:38 +0000 Subject: [PATCH 2/7] Revert "dq_mnb_fusion change" This reverts commit 5f7df044813ada934b7c38a8ce1c6778c882176f. --- .../core/optimizer/dq_matmulnbits_fusion.cc | 516 +++++------------- .../core/optimizer/dq_matmulnbits_fusion.h | 10 +- 2 files changed, 126 insertions(+), 400 deletions(-) diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc index 13339d809374d..f9ae13808cf2c 100644 --- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc @@ -27,61 +27,8 @@ namespace { // Utility helpers // --------------------------------------------------------------------------- -struct QuantTypeInfo { - int64_t bits; - bool is_signed; -}; - -// Map ONNX data types to quantization bit-width info. -std::optional GetQuantTypeInfo(int32_t data_type) { - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_UINT2: - return QuantTypeInfo{2, false}; - case ONNX_NAMESPACE::TensorProto_DataType_INT2: - return QuantTypeInfo{2, true}; - case ONNX_NAMESPACE::TensorProto_DataType_UINT4: - return QuantTypeInfo{4, false}; - case ONNX_NAMESPACE::TensorProto_DataType_INT4: - return QuantTypeInfo{4, true}; - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - return QuantTypeInfo{8, false}; - default: - return std::nullopt; - } -} - -// Extract a single N-bit element from packed data. -// For sub-byte types, elements are packed with even indices in the low bits. -uint8_t GetPackedElement(const uint8_t* packed, size_t index, size_t num_elements, int64_t bits) { - ORT_ENFORCE(index < num_elements, "GetPackedElement: index ", index, - " out of bounds (num_elements=", num_elements, ")"); - if (bits == 8) { - return packed[index]; - } - const int elems_per_byte = 8 / static_cast(bits); - const size_t byte_index = index / elems_per_byte; - const int bit_offset = static_cast((index % elems_per_byte) * bits); - const uint8_t mask = static_cast((1 << bits) - 1); - return static_cast((packed[byte_index] >> bit_offset) & mask); -} - -// Set a single N-bit element in packed data. -void SetPackedElement(uint8_t* packed, size_t index, uint8_t value, int64_t bits) { - if (bits == 8) { - packed[index] = value; - return; - } - const int elems_per_byte = 8 / static_cast(bits); - const size_t byte_index = index / elems_per_byte; - const int bit_offset = static_cast((index % elems_per_byte) * bits); - const uint8_t mask = static_cast((1 << bits) - 1); - packed[byte_index] = static_cast( - (packed[byte_index] & ~(mask << bit_offset)) | ((value & mask) << bit_offset)); -} - -bool IsUniformPackedValue(const Initializer& init, uint8_t expected_value, int64_t bits) { - const auto qtype = GetQuantTypeInfo(init.data_type()); - if (!qtype || qtype->bits != bits) { +bool IsUniformPackedUint4Value(const Initializer& init, uint8_t expected_nibble) { + if (init.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) { return false; } @@ -91,10 +38,11 @@ bool IsUniformPackedValue(const Initializer& init, uint8_t expected_value, int64 } const auto packed = init.DataAsByteSpan(); - const uint8_t mask = static_cast((1 << bits) - 1); - const uint8_t expected = static_cast(expected_value & mask); + const uint8_t expected = static_cast(expected_nibble & 0x0F); for (size_t i = 0; i < values_count; ++i) { - if (GetPackedElement(packed.data(), i, values_count, bits) != expected) { + const uint8_t byte = packed[i / 2]; + const uint8_t value = (i % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F); + if (value != expected) { return false; } } @@ -106,14 +54,16 @@ bool HasRank2Shape(const ONNX_NAMESPACE::TensorProto& tp, int64_t dim0, int64_t return tp.dims_size() == 2 && tp.dims(0) == dim0 && tp.dims(1) == dim1; } -// Compute the number of bytes needed to store 'count' N-bit elements. -int64_t PackedByteSize(int64_t count, int64_t bits) { - return (count * bits + 7) / 8; +uint8_t GetPackedUint4Element(const uint8_t* packed, size_t index, size_t num_elements) { + ORT_ENFORCE(index < num_elements, "GetPackedUint4Element: index ", index, + " out of bounds (num_elements=", num_elements, ")"); + const uint8_t packed_byte = packed[index / 2]; + return (index % 2 == 0) ? static_cast(packed_byte & 0x0F) + : static_cast((packed_byte >> 4) & 0x0F); } -// Pack N-bit elements row-by-row from DQ layout to MatMulNBits layout. -void PackRows(const Initializer& src, int64_t rows, int64_t cols, int64_t bits, uint8_t* dst) { - const int64_t row_bytes = PackedByteSize(cols, bits); +void PackUint4Rows(const Initializer& src, int64_t rows, int64_t cols, uint8_t* dst) { + const int64_t row_bytes = (cols + 1) / 2; const size_t dst_bytes = SafeInt(rows) * row_bytes; const size_t total_elements = SafeInt(rows) * cols; memset(dst, 0, dst_bytes); @@ -122,21 +72,28 @@ void PackRows(const Initializer& src, int64_t rows, int64_t cols, int64_t bits, for (int64_t r = 0; r < rows; ++r) { for (int64_t c = 0; c < cols; ++c) { const size_t src_index = SafeInt(r) * cols + c; - const uint8_t value = GetPackedElement(src_packed.data(), src_index, total_elements, bits); + const uint8_t value = GetPackedUint4Element(src_packed.data(), src_index, total_elements); - const size_t dst_index = SafeInt(r) * row_bytes * (8 / bits) + c; - SetPackedElement(dst, dst_index, value, bits); + const size_t dst_index = SafeInt(r) * row_bytes + c / 2; + if ((c & 1) == 0) { + dst[dst_index] = value; + } else { + dst[dst_index] = static_cast(dst[dst_index] | (value << 4)); + } } } } -// Transpose and pack N-bit weights from DQ axis=0 layout [K, N] to MatMulNBits layout -// [N, k_blocks, blob_size]. blob_size = block_size * bits / 8. +// Transpose and pack UINT4 weights from DQ axis=0 layout [K, N] to MatMulNBits layout [N, k_blocks, blob_size]. +// Source: row-major UINT4 with quantization along K (axis=0), shape [K, N]. +// The nibble ordering follows ONNX UINT4 convention: even indices in the low nibble, +// odd indices in the high nibble of each byte. +// Dest: UINT8 [N, k_blocks, block_size/2] where each byte packs two 4-bit weights. void TransposePackWeightsAxis0( - const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size, int64_t bits, + const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size, uint8_t* dst) { const int64_t k_blocks = (K + block_size - 1) / block_size; - const int64_t blob_size = block_size * bits / 8; + const int64_t blob_size = block_size / 2; const size_t dst_bytes = SafeInt(N) * k_blocks * blob_size; const size_t total_elements = SafeInt(K) * N; memset(dst, 0, dst_bytes); @@ -144,23 +101,26 @@ void TransposePackWeightsAxis0( for (int64_t n = 0; n < N; ++n) { for (int64_t k = 0; k < K; ++k) { const size_t src_index = SafeInt(k) * N + n; - const uint8_t val = GetPackedElement(src_packed, src_index, total_elements, bits); + const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements); const int64_t kb = k / block_size; const int64_t off = k % block_size; - // Destination: element index within the block's blob - const size_t dst_elem = SafeInt(n) * k_blocks * block_size + kb * block_size + off; - SetPackedElement(dst, dst_elem, val, bits); + const size_t dst_byte = SafeInt(n) * k_blocks * blob_size + kb * blob_size + off / 2; + if (off % 2 == 0) { + dst[dst_byte] = static_cast((dst[dst_byte] & 0xF0) | val); + } else { + dst[dst_byte] = static_cast((dst[dst_byte] & 0x0F) | (val << 4)); + } } } } -// Transpose and pack N-bit zero points from DQ axis=0 layout [k_blocks, N] to -// MatMulNBits layout UINT8 [N, packed_zp_bytes_per_n]. +// Transpose and pack UINT4 zero points from DQ axis=0 layout [k_blocks, N] to +// MatMulNBits layout UINT8 [N, ceil(k_blocks/2)]. void TransposePackZPAxis0( - const uint8_t* src_packed, int64_t k_blocks, int64_t N, int64_t bits, + const uint8_t* src_packed, int64_t k_blocks, int64_t N, uint8_t* dst) { - const int64_t zp_bytes_per_n = PackedByteSize(k_blocks, bits); + const int64_t zp_bytes_per_n = (k_blocks + 1) / 2; const size_t dst_bytes = SafeInt(N) * zp_bytes_per_n; const size_t total_elements = SafeInt(k_blocks) * N; memset(dst, 0, dst_bytes); @@ -168,50 +128,33 @@ void TransposePackZPAxis0( for (int64_t n = 0; n < N; ++n) { for (int64_t kb = 0; kb < k_blocks; ++kb) { const size_t src_index = SafeInt(kb) * N + n; - const uint8_t val = GetPackedElement(src_packed, src_index, total_elements, bits); + const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements); - const size_t dst_elem = SafeInt(n) * k_blocks + kb; - SetPackedElement(dst, dst_elem, val, bits); + const size_t dst_byte = SafeInt(n) * zp_bytes_per_n + kb / 2; + if (kb % 2 == 0) { + dst[dst_byte] = static_cast((dst[dst_byte] & 0xF0) | val); + } else { + dst[dst_byte] = static_cast((dst[dst_byte] & 0x0F) | (val << 4)); + } } } } -// Returns the Cast node's target element type (the "to" attribute), or nullopt if invalid. -std::optional GetCastToType(const Node& cast_node) { - const auto* to_attr = graph_utils::GetNodeAttribute(cast_node, "to"); - if (!to_attr) return std::nullopt; - return static_cast(to_attr->i()); -} - -bool IsFloatOrFloat16(int32_t dt) { - return dt == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || - dt == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; -} - // --------------------------------------------------------------------------- // Match structs // --------------------------------------------------------------------------- struct FusionMatch { NodeIndex matmul_idx; - std::optional weight_cast_idx; // Cast on weight path (between Transpose and MatMul) + std::optional cast_idx; NodeIndex transpose_idx; NodeIndex reshape_idx; NodeIndex dq_idx; - int64_t bits; - std::optional input_a_cast_idx; // Cast on input A - std::optional output_cast_idx; // Cast on MatMul output - int32_t effective_dt_a; // T1 for MatMulNBits (scale type) }; struct DirectDQMatch { NodeIndex matmul_idx; NodeIndex dq_idx; - int64_t bits; - std::optional weight_cast_idx; // Cast on weight path (between DQ and MatMul) - std::optional input_a_cast_idx; // Cast on input A - std::optional output_cast_idx; // Cast on MatMul output - int32_t effective_dt_a; // T1 for MatMulNBits (scale type) }; // --------------------------------------------------------------------------- @@ -245,7 +188,6 @@ bool ValidateGemmForFusion(const Node& gemm_node, int64_t N) { // --------------------------------------------------------------------------- // Pattern 1 matching: DQ -> Reshape -> Transpose -> [Cast] -> MatMul/Gemm -// With optional Cast on input A and/or MatMul output for FP16 models. // --------------------------------------------------------------------------- std::vector CollectReshapeTransposeMatches( @@ -263,13 +205,12 @@ std::vector CollectReshapeTransposeMatches( const auto& mm_inputs = node->InputDefs(); if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue; - // Trace weight path: MatMul input B <- [Cast] <- Transpose <- Reshape <- DQ - const Node* weight_cast_node = nullptr; + const Node* cast_node = nullptr; const Node* transpose_node = graph.GetProducerNode(mm_inputs[1]->Name()); if (transpose_node && transpose_node->OpType() == "Cast") { - weight_cast_node = transpose_node; - if (weight_cast_node->GetOutputEdgesCount() != 1) continue; - const auto& cast_inputs = weight_cast_node->InputDefs(); + cast_node = transpose_node; + if (cast_node->GetOutputEdgesCount() != 1) continue; + const auto& cast_inputs = cast_node->InputDefs(); if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue; transpose_node = graph.GetProducerNode(cast_inputs[0]->Name()); } @@ -302,14 +243,11 @@ std::vector CollectReshapeTransposeMatches( if (block_size < 16 || ((block_size - 1) & block_size)) continue; } - // Validate weight type: must be a supported quantized type const auto* weight_arg = dq_node->InputDefs()[0]; if (!weight_arg || !weight_arg->Exists()) continue; const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); if (!weight_const_tp) continue; - const auto weight_qtype = GetQuantTypeInfo(weight_const_tp->data_type()); - if (!weight_qtype) continue; - const int64_t bits = weight_qtype->bits; + if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; if (weight_const_tp->dims_size() != 3) continue; const int64_t N = weight_const_tp->dims(0); const int64_t blocks = weight_const_tp->dims(1); @@ -318,59 +256,18 @@ std::vector CollectReshapeTransposeMatches( if (bs_dim != block_size) continue; const int64_t K = SafeInt(blocks) * bs_dim; - // Scale type determines the effective T1 for MatMulNBits const auto* scale_arg = dq_node->InputDefs()[1]; if (!scale_arg || !scale_arg->Exists()) continue; const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); if (!scale_const_tp) continue; int32_t dt_scale = scale_const_tp->data_type(); - if (!IsFloatOrFloat16(dt_scale)) continue; + if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; - // Check input A type, looking through optional Cast const auto* a_arg = mm_inputs[0]; if (!a_arg || !a_arg->TypeAsProto()) continue; int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); - - const Node* input_a_cast_node = nullptr; - int32_t effective_dt_a = dt_a; - - if (dt_a != dt_scale) { - // Check if input A is produced by a Cast from dt_scale - const Node* a_producer = graph.GetProducerNode(a_arg->Name()); - if (a_producer && a_producer->OpType() == "Cast") { - const auto cast_to = GetCastToType(*a_producer); - if (cast_to && *cast_to == dt_a) { - const auto* cast_in = a_producer->InputDefs().empty() ? nullptr : a_producer->InputDefs()[0]; - if (cast_in && cast_in->TypeAsProto()) { - int32_t dt_cast_in = cast_in->TypeAsProto()->tensor_type().elem_type(); - if (dt_cast_in == dt_scale && a_producer->GetOutputEdgesCount() == 1) { - input_a_cast_node = a_producer; - effective_dt_a = dt_scale; - } - } - } - } - if (effective_dt_a != dt_scale) continue; - } - - // Validate weight-path Cast: must cast to dt_a (the MatMul compute type) - if (weight_cast_node) { - const auto cast_to = GetCastToType(*weight_cast_node); - if (!cast_to || *cast_to != dt_a) continue; - } - - // Check for Cast on MatMul output - const Node* output_cast_node = nullptr; - if (node->GetOutputEdgesCount() == 1) { - const auto edge = node->OutputEdgesBegin(); - const Node& consumer = edge->GetNode(); - if (consumer.OpType() == "Cast") { - const auto cast_to = GetCastToType(consumer); - if (cast_to && *cast_to == dt_scale && consumer.GetOutputEdgesCount() >= 1) { - output_cast_node = &consumer; - } - } - } + if (dt_a != dt_scale) continue; const auto* reshape_shape_arg = reshape_node->InputDefs().size() > 1 ? reshape_node->InputDefs()[1] : nullptr; @@ -440,35 +337,37 @@ std::vector CollectReshapeTransposeMatches( if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; - // Validate zero-point type matches weight type + if (cast_node) { + const auto* cast_in = cast_node->InputDefs().empty() ? nullptr : cast_node->InputDefs()[0]; + const auto* cast_out = cast_node->OutputDefs().empty() ? nullptr : cast_node->OutputDefs()[0]; + if (!cast_in || !cast_out || !cast_in->TypeAsProto() || !cast_out->TypeAsProto()) continue; + if (cast_in->TypeAsProto()->tensor_type().elem_type() != + cast_out->TypeAsProto()->tensor_type().elem_type()) { + continue; + } + } + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; bool has_zp = zp_arg && zp_arg->Exists(); if (has_zp) { const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); - if (!zp_const_tp) continue; - const auto zp_qtype = GetQuantTypeInfo(zp_const_tp->data_type()); - if (!zp_qtype || zp_qtype->bits != bits) continue; + if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; } LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched pattern at MatMul node '" - << node->Name() << "' (bits=" << bits << ")"; + << node->Name() << "'"; matches.push_back({node->Index(), - weight_cast_node ? std::optional(weight_cast_node->Index()) : std::nullopt, + cast_node ? std::optional(cast_node->Index()) : std::nullopt, transpose_node->Index(), - reshape_node->Index(), dq_node->Index(), - bits, - input_a_cast_node ? std::optional(input_a_cast_node->Index()) : std::nullopt, - output_cast_node ? std::optional(output_cast_node->Index()) : std::nullopt, - dt_scale}); + reshape_node->Index(), dq_node->Index()}); } return matches; } // --------------------------------------------------------------------------- -// Pattern 2 matching: direct DQ(axis=0, 2D) -> [Cast] -> MatMul/Gemm -// With optional Cast on input A and/or MatMul output for FP16 models. +// Pattern 2 matching: direct DQ(axis=0, 2D UINT4) -> MatMul/Gemm // --------------------------------------------------------------------------- std::vector CollectDirectDQMatches( @@ -488,17 +387,7 @@ std::vector CollectDirectDQMatches( const auto& mm_inputs = node->InputDefs(); if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue; - // Trace weight path: MatMul input B <- [Cast] <- DQ - const Node* weight_cast_node = nullptr; const Node* dq_node = graph.GetProducerNode(mm_inputs[1]->Name()); - if (dq_node && dq_node->OpType() == "Cast") { - weight_cast_node = dq_node; - if (weight_cast_node->GetOutputEdgesCount() != 1) continue; - const auto& cast_inputs = weight_cast_node->InputDefs(); - if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue; - dq_node = graph.GetProducerNode(cast_inputs[0]->Name()); - } - if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue; if (dq_node->GetOutputEdgesCount() != 1) continue; @@ -519,9 +408,7 @@ std::vector CollectDirectDQMatches( if (!weight_arg || !weight_arg->Exists()) continue; const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); if (!weight_const_tp) continue; - const auto weight_qtype = GetQuantTypeInfo(weight_const_tp->data_type()); - if (!weight_qtype) continue; - const int64_t bits = weight_qtype->bits; + if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; if (weight_const_tp->dims_size() != 2) continue; const int64_t K = weight_const_tp->dims(0); const int64_t N = weight_const_tp->dims(1); @@ -533,74 +420,28 @@ std::vector CollectDirectDQMatches( const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); if (!scale_const_tp) continue; int32_t dt_scale = scale_const_tp->data_type(); - if (!IsFloatOrFloat16(dt_scale)) continue; + if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; if (!HasRank2Shape(*scale_const_tp, k_blocks, N)) continue; - // Check input A type, looking through optional Cast const auto* a_arg = mm_inputs[0]; if (!a_arg || !a_arg->TypeAsProto()) continue; int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); - - const Node* input_a_cast_node = nullptr; - int32_t effective_dt_a = dt_a; - - if (dt_a != dt_scale) { - const Node* a_producer = graph.GetProducerNode(a_arg->Name()); - if (a_producer && a_producer->OpType() == "Cast") { - const auto cast_to = GetCastToType(*a_producer); - if (cast_to && *cast_to == dt_a) { - const auto* cast_in = a_producer->InputDefs().empty() ? nullptr : a_producer->InputDefs()[0]; - if (cast_in && cast_in->TypeAsProto()) { - int32_t dt_cast_in = cast_in->TypeAsProto()->tensor_type().elem_type(); - if (dt_cast_in == dt_scale && a_producer->GetOutputEdgesCount() == 1) { - input_a_cast_node = a_producer; - effective_dt_a = dt_scale; - } - } - } - } - if (effective_dt_a != dt_scale) continue; - } - - // Validate weight-path Cast - if (weight_cast_node) { - const auto cast_to = GetCastToType(*weight_cast_node); - if (!cast_to || *cast_to != dt_a) continue; - } - - // Check for Cast on MatMul output - const Node* output_cast_node = nullptr; - if (node->GetOutputEdgesCount() == 1) { - const auto edge = node->OutputEdgesBegin(); - const Node& consumer = edge->GetNode(); - if (consumer.OpType() == "Cast") { - const auto cast_to = GetCastToType(consumer); - if (cast_to && *cast_to == dt_scale && consumer.GetOutputEdgesCount() >= 1) { - output_cast_node = &consumer; - } - } - } + if (dt_a != dt_scale) continue; const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; bool has_zp = zp_arg && zp_arg->Exists(); if (has_zp) { const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); - if (!zp_const_tp) continue; - const auto zp_qtype = GetQuantTypeInfo(zp_const_tp->data_type()); - if (!zp_qtype || zp_qtype->bits != bits) continue; + if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; if (!HasRank2Shape(*zp_const_tp, k_blocks, N)) continue; } if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched direct DQ->MatMul pattern at node '" - << node->Name() << "' (K=" << K << ", N=" << N - << ", block_size=" << block_size << ", bits=" << bits << ")"; - direct_matches.push_back({node->Index(), dq_node->Index(), bits, - weight_cast_node ? std::optional(weight_cast_node->Index()) : std::nullopt, - input_a_cast_node ? std::optional(input_a_cast_node->Index()) : std::nullopt, - output_cast_node ? std::optional(output_cast_node->Index()) : std::nullopt, - dt_scale}); + << node->Name() << "' (K=" << K << ", N=" << N << ", block_size=" << block_size << ")"; + direct_matches.push_back({node->Index(), dq_node->Index()}); } return direct_matches; @@ -618,22 +459,15 @@ void ApplyReshapeTransposeFusions( const logging::Logger& logger) { for (const auto& match : matches) { const Node* mm_node = graph.GetNode(match.matmul_idx); - const Node* weight_cast_node = match.weight_cast_idx ? graph.GetNode(*match.weight_cast_idx) : nullptr; + const Node* cast_node = match.cast_idx ? graph.GetNode(*match.cast_idx) : nullptr; const Node* tp_node = graph.GetNode(match.transpose_idx); const Node* dq_node = graph.GetNode(match.dq_idx); const Node* reshape_node = graph.GetNode(match.reshape_idx); - const Node* input_a_cast_node = match.input_a_cast_idx ? graph.GetNode(*match.input_a_cast_idx) : nullptr; - const Node* output_cast_node = match.output_cast_idx ? graph.GetNode(*match.output_cast_idx) : nullptr; if (!mm_node || !tp_node || !dq_node || !reshape_node || - (match.weight_cast_idx && !weight_cast_node) || - (match.input_a_cast_idx && !input_a_cast_node) || - (match.output_cast_idx && !output_cast_node)) { + (match.cast_idx && !cast_node)) { continue; } - const int64_t bits = match.bits; - const int32_t effective_dt_a = match.effective_dt_a; - const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; @@ -651,8 +485,8 @@ void ApplyReshapeTransposeFusions( if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; } - const auto weight_qtype = GetQuantTypeInfo(weight_tp->data_type()); - if (!weight_qtype || weight_qtype->bits != bits || weight_tp->dims_size() != 3) { + if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || + weight_tp->dims_size() != 3) { continue; } @@ -661,11 +495,12 @@ void ApplyReshapeTransposeFusions( const int64_t bs_dim = weight_tp->dims(2); if (N <= 0 || quant_num <= 0 || bs_dim <= 0 || bs_dim != block_size) continue; const int64_t K = SafeInt(quant_num) * bs_dim; - const int64_t blob_bytes = PackedByteSize(block_size, bits); + const int64_t blob_bytes = (block_size + 1) / 2; Initializer weight_src(graph, *weight_tp, graph.ModelPath()); Initializer scale_src(graph, *scale_tp, graph.ModelPath()); - if (!IsFloatOrFloat16(scale_src.data_type())) { + if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { continue; } @@ -688,11 +523,10 @@ void ApplyReshapeTransposeFusions( std::string zp_dst_name; std::optional zp_dst; - const int64_t zp_packed_size = SafeInt(N) * PackedByteSize(quant_num, bits); + const int64_t zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); - bool elide_default_zp = false; + bool elide_default_uint4_zp8_input = false; std::optional zp_src; - const uint8_t mnb_default_zp = static_cast(1 << (bits - 1)); // 2^(bits-1) const auto weight_bytes = weight_src.DataAsByteSpan(); if (weight_bytes.size() != static_cast(weight_dst.SizeInBytes())) continue; @@ -708,29 +542,30 @@ void ApplyReshapeTransposeFusions( if (zp_tp) { zp_src.emplace(graph, *zp_tp, graph.ModelPath()); - const auto zp_qtype = GetQuantTypeInfo(zp_src->data_type()); - if (!zp_qtype || zp_qtype->bits != bits) continue; + if (zp_src->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; if (zp_src->size() != static_cast(N * quant_num)) continue; - if (IsUniformPackedValue(*zp_src, mnb_default_zp, bits)) { - elide_default_zp = true; + const bool is_default_uint4_8 = + IsUniformPackedUint4Value(*zp_src, /*expected_nibble*/ 8); + if (is_default_uint4_8) { + elide_default_uint4_zp8_input = true; } else { zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); - zp_dst = Tensor(uint8_type, TensorShape{zp_packed_size}, cpu_allocator); - PackRows(*zp_src, N, quant_num, bits, zp_dst->MutableData()); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + PackUint4Rows(*zp_src, N, quant_num, zp_dst->MutableData()); } } else { - // DequantizeLinear default zero-point is 0, while MatMulNBits - // default is 2^(bits-1). Emit explicit zeros to preserve semantics. + // DequantizeLinear default zero-point for uint4 is 0, while MatMulNBits + // default is 8. Emit explicit zeros to preserve semantics. zp_dst_name = graph.GenerateNodeArgName("fused_DQ_zp_mnb"); - zp_dst = Tensor(uint8_type, TensorShape{zp_packed_size}, cpu_allocator); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); } auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); std::optional zp_mnb_tp; - if (zp_dst && !elide_default_zp) { + if (zp_dst && !elide_default_uint4_zp8_input) { zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } @@ -738,16 +573,11 @@ void ApplyReshapeTransposeFusions( utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); - utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); - // Determine input A: use pre-Cast value if input A Cast is being removed - NodeArg* mnb_input_a = input_a_cast_node - ? const_cast(input_a_cast_node->InputDefs()[0]) - : const_cast(mm_node->InputDefs()[0]); - std::vector mnb_inputs; - mnb_inputs.push_back(mnb_input_a); + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { @@ -768,20 +598,8 @@ void ApplyReshapeTransposeFusions( fused_with_bias = true; } - // Determine output: if output Cast exists, take over its output; otherwise MatMul's output std::vector mnb_outputs; - if (output_cast_node) { - mnb_outputs.push_back(const_cast(output_cast_node->OutputDefs()[0])); - } else if (effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { - // MatMulNBits outputs T1 (=effective_dt_a) but consumers expect the original type. - // Create a new intermediate output for MatMulNBits and insert a Cast after it. - auto& mnb_out_arg = graph.GetOrCreateNodeArg( - graph.GenerateNodeArgName("mnb_out"), - nullptr); - mnb_outputs.push_back(&mnb_out_arg); - } else { - mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); - } + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); auto& mnb_node = graph.AddNode( graph.GenerateNodeName("DQFusedMatMulNBits"), @@ -790,37 +608,12 @@ void ApplyReshapeTransposeFusions( mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); - // If we need a Cast after MatMulNBits (no output_cast_node to absorb the type difference) - if (!output_cast_node && - effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { - int32_t original_dt = mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - NodeAttributes cast_attrs; - utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast(original_dt)), cast_attrs); - graph.AddNode( - graph.GenerateNodeName("DQFusedMatMulNBits_Cast"), - "Cast", "Cast MNB output to original type", - {mnb_outputs[0]}, - {const_cast(mm_node->OutputDefs()[0])}, - &cast_attrs); - } - - // Remove nodes in reverse dependency order - if (output_cast_node) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.output_cast_idx.value())); - graph.RemoveNode(match.output_cast_idx.value()); - } - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); graph.RemoveNode(match.matmul_idx); - if (input_a_cast_node) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.input_a_cast_idx.value())); - graph.RemoveNode(match.input_a_cast_idx.value()); - } - - if (match.weight_cast_idx && graph.GetNode(*match.weight_cast_idx)) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.weight_cast_idx)); - graph.RemoveNode(*match.weight_cast_idx); + if (match.cast_idx && graph.GetNode(*match.cast_idx)) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.cast_idx)); + graph.RemoveNode(*match.cast_idx); } graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.transpose_idx)); @@ -833,19 +626,16 @@ void ApplyReshapeTransposeFusions( graph.RemoveNode(match.dq_idx); LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused DQ+Reshape+Transpose" - << (match.weight_cast_idx ? "+Cast" : "") + << (match.cast_idx ? "+Cast" : "") << "+MatMul/Gemm -> MatMulNBits" - << " (bits=" << bits << ")" << (fused_with_bias ? " (bias preserved)" : "") - << (elide_default_zp ? " (default zp elided)" : "") - << (input_a_cast_node ? " (input Cast removed)" : "") - << (output_cast_node ? " (output Cast removed)" : ""); + << (elide_default_uint4_zp8_input ? " (default UINT4 zp8 elided)" : ""); modified = true; } } // --------------------------------------------------------------------------- -// Pattern 2 rewriting: direct DQ(axis=0) + [Cast] + MatMul/Gemm -> MatMulNBits +// Pattern 2 rewriting: direct DQ(axis=0) + MatMul/Gemm -> MatMulNBits // --------------------------------------------------------------------------- void ApplyDirectDQFusions( @@ -857,18 +647,7 @@ void ApplyDirectDQFusions( for (const auto& match : matches) { const Node* mm_node = graph.GetNode(match.matmul_idx); const Node* dq_node = graph.GetNode(match.dq_idx); - const Node* weight_cast_node = match.weight_cast_idx ? graph.GetNode(*match.weight_cast_idx) : nullptr; - const Node* input_a_cast_node = match.input_a_cast_idx ? graph.GetNode(*match.input_a_cast_idx) : nullptr; - const Node* output_cast_node = match.output_cast_idx ? graph.GetNode(*match.output_cast_idx) : nullptr; - if (!mm_node || !dq_node || - (match.weight_cast_idx && !weight_cast_node) || - (match.input_a_cast_idx && !input_a_cast_node) || - (match.output_cast_idx && !output_cast_node)) { - continue; - } - - const int64_t bits = match.bits; - const int32_t effective_dt_a = match.effective_dt_a; + if (!mm_node || !dq_node) continue; const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -887,14 +666,14 @@ void ApplyDirectDQFusions( if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; } - const auto weight_qtype = GetQuantTypeInfo(weight_tp->data_type()); - if (!weight_qtype || weight_qtype->bits != bits || weight_tp->dims_size() != 2) continue; + if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || + weight_tp->dims_size() != 2) continue; const int64_t K = weight_tp->dims(0); const int64_t N = weight_tp->dims(1); if (K <= 0 || N <= 0 || block_size <= 0 || K % block_size != 0) continue; const int64_t k_blocks = K / block_size; - const int64_t blob_bytes = block_size * bits / 8; + const int64_t blob_bytes = block_size / 2; if (!HasRank2Shape(*scale_tp, k_blocks, N)) continue; if (zp_tp && !HasRank2Shape(*zp_tp, k_blocks, N)) continue; @@ -902,7 +681,8 @@ void ApplyDirectDQFusions( const size_t required_weight_bytes = SafeInt(N) * k_blocks * blob_bytes; if (weight_src.DataAsByteSpan().size() < required_weight_bytes) continue; Initializer scale_src(graph, *scale_tp, graph.ModelPath()); - if (!IsFloatOrFloat16(scale_src.data_type())) continue; + if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum( ONNX_NAMESPACE::TensorProto_DataType_UINT8) @@ -914,7 +694,7 @@ void ApplyDirectDQFusions( auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb"); auto weight_dst = Tensor(uint8_type, TensorShape{N, k_blocks, blob_bytes}, cpu_allocator); - TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size, bits, + TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size, weight_dst.MutableData()); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb"); @@ -938,27 +718,25 @@ void ApplyDirectDQFusions( std::string zp_dst_name; std::optional zp_dst; - const int64_t zp_bytes_total = SafeInt(N) * PackedByteSize(k_blocks, bits); + const int64_t zp_bytes_total = SafeInt(N) * ((k_blocks + 1) / 2); bool elide_zp = false; - const uint8_t mnb_default_zp = static_cast(1 << (bits - 1)); if (zp_tp) { Initializer zp_src(graph, *zp_tp, graph.ModelPath()); - const auto zp_qtype = GetQuantTypeInfo(zp_src.data_type()); - if (!zp_qtype || zp_qtype->bits != bits) continue; + if (zp_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; if (zp_src.size() != static_cast(k_blocks * N)) continue; - if (IsUniformPackedValue(zp_src, mnb_default_zp, bits)) { + if (IsUniformPackedUint4Value(zp_src, 8)) { elide_zp = true; } else { zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); - TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N, bits, + TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N, zp_dst->MutableData()); } } else { - // DQ default ZP is 0, MatMulNBits default is 2^(bits-1). Emit explicit zeros. + // DQ default ZP for UINT4 is 0, MatMulNBits default is 8. Emit explicit zeros. zp_dst_name = graph.GenerateNodeArgName("direct_DQ_zp_mnb"); zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); @@ -975,16 +753,11 @@ void ApplyDirectDQFusions( utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); - utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); - // Determine input A: use pre-Cast value if input A Cast is being removed - NodeArg* mnb_input_a = input_a_cast_node - ? const_cast(input_a_cast_node->InputDefs()[0]) - : const_cast(mm_node->InputDefs()[0]); - std::vector mnb_inputs; - mnb_inputs.push_back(mnb_input_a); + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { @@ -1004,18 +777,8 @@ void ApplyDirectDQFusions( fused_with_bias = true; } - // Determine output: if output Cast exists, take over its output; otherwise MatMul's output std::vector mnb_outputs; - if (output_cast_node) { - mnb_outputs.push_back(const_cast(output_cast_node->OutputDefs()[0])); - } else if (effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { - auto& mnb_out_arg = graph.GetOrCreateNodeArg( - graph.GenerateNodeArgName("mnb_out"), - nullptr); - mnb_outputs.push_back(&mnb_out_arg); - } else { - mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); - } + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); auto& mnb_node = graph.AddNode( graph.GenerateNodeName("DirectDQFusedMatMulNBits"), @@ -1024,49 +787,16 @@ void ApplyDirectDQFusions( mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); - // If we need a Cast after MatMulNBits - if (!output_cast_node && - effective_dt_a != mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { - int32_t original_dt = mm_node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - NodeAttributes cast_attrs; - utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast(original_dt)), cast_attrs); - graph.AddNode( - graph.GenerateNodeName("DirectDQFusedMatMulNBits_Cast"), - "Cast", "Cast MNB output to original type", - {mnb_outputs[0]}, - {const_cast(mm_node->OutputDefs()[0])}, - &cast_attrs); - } - - // Remove nodes in reverse dependency order - if (output_cast_node) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.output_cast_idx.value())); - graph.RemoveNode(match.output_cast_idx.value()); - } - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); graph.RemoveNode(match.matmul_idx); - if (input_a_cast_node) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.input_a_cast_idx.value())); - graph.RemoveNode(match.input_a_cast_idx.value()); - } - - if (weight_cast_node) { - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.weight_cast_idx.value())); - graph.RemoveNode(match.weight_cast_idx.value()); - } - graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx)); graph.RemoveNode(match.dq_idx); LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused direct DQ(axis=0)+MatMul/Gemm -> MatMulNBits" - << " (K=" << K << ", N=" << N << ", block_size=" << block_size - << ", bits=" << bits << ")" + << " (K=" << K << ", N=" << N << ", block_size=" << block_size << ")" << (fused_with_bias ? " (bias preserved)" : "") - << (elide_zp ? " (default zp elided)" : "") - << (input_a_cast_node ? " (input Cast removed)" : "") - << (output_cast_node ? " (output Cast removed)" : ""); + << (elide_zp ? " (default UINT4 zp8 elided)" : ""); modified = true; } } diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h index b5d3e60f1be19..97c0debd760c0 100644 --- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h @@ -13,14 +13,10 @@ namespace onnxruntime { // Fuses DequantizeLinear chains back into a single MatMulNBits contrib op. // -// Supported patterns (weight types: UINT2, INT2, UINT4, INT4, UINT8): -// Pattern 1: DQ(3D, axis=2) -> Reshape(2D) -> Transpose([1,0]) +// Supported patterns: +// Pattern 1: DQ(3D, UINT4, axis=2) -> Reshape(2D) -> Transpose([1,0]) // -> [optional Cast] -> MatMul/Gemm => MatMulNBits -// Pattern 2: DQ(2D, axis=0) -> [optional Cast] -> MatMul/Gemm => MatMulNBits -// -// FP16 Cast handling: Cast nodes on input A (FP16→FP32), the weight path -// (FP16→FP32), and output (FP32→FP16) are absorbed into the fusion so that -// MatMulNBits operates directly on FP16 inputs/outputs. +// Pattern 2: DQ(2D, UINT4, axis=0) -> MatMul/Gemm => MatMulNBits // // These patterns are produced when a quantized model goes through external // toolchains that lower MatMulNBits to DQ + reshape/transpose + MatMul From 9e131cf6a8e9d1e78e3a844077de91a23b05eae9 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 21:55:11 +0000 Subject: [PATCH 3/7] mnb fusion rules extension --- onnxruntime/core/mlas/lib/q4_dq.cpp | 189 +++++++++- .../selectors_actions/qdq_actions.cc | 344 +++++++++++++++--- .../selectors_actions/qdq_actions.h | 14 + .../qdq_selector_action_transformer.cc | 19 + .../selectors_actions/qdq_selectors.cc | 145 +++++++- .../selectors_actions/qdq_selectors.h | 21 ++ .../qdq_matmulnbits_transformer_test.cc | 201 +++++++++- 7 files changed, 870 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index fbbf4005ae4a5..f9019009ae644 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -663,6 +663,9 @@ struct BlockwiseQDQQuantizer { return (val >> (idx << 1)) & 0x3; } else if constexpr (qbits == 4) { return (val >> (idx << 2)) & 0xF; + } else if constexpr (qbits == 8) { + (void)idx; + return val; } } @@ -674,6 +677,10 @@ struct BlockwiseQDQQuantizer { } else if constexpr (qbits == 4) { auto shift = idx << 2; return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } else if constexpr (qbits == 8) { + (void)idx; + (void)dst; + return val; } } @@ -813,21 +820,87 @@ struct BlockwiseQDQQuantizer { src_zero_points || signed_quant || dst_zero_points, "Unsigned quant types without zero points must allocate zero points with value 0." ); - // Must avoid multiple thread write to a single byte, which means the starting index - // of a thread block must be even. To achieve that, we need to customize the thread - // block size based on the parity of columns. - if (columns & 1) { - TransposeColumnWiseQuantizedPackUnaligned( - src_weights, src_scales, src_zero_points, - dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool + + if constexpr (qbits == 8) { + // 8-bit: each element is one byte, no sub-byte packing needed. + // Simple byte-level transpose from [rows, columns] to [columns, k_blocks, block_size]. + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto dst_bytes_per_quant_blk = quant_block_size; // 8 bits = 1 byte per element + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + + // Transpose weights: src [rows, columns] -> dst [columns, k_blocks, block_size] + MlasTryBatchParallel( + thread_pool, static_cast(row_quant_blk_num * columns), + [&](ptrdiff_t thread_blk_idx) { + auto row_blk = static_cast(thread_blk_idx / columns); + auto col = static_cast(thread_blk_idx % columns); + + auto src_row_start = row_blk * quant_block_size; + auto src_row_end = std::min(src_row_start + quant_block_size, rows); + + auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk; + for (auto r = src_row_start; r < src_row_end; ++r) { + auto src_val = src_weights[r * columns + col]; + if constexpr (signed_quant) { + src_val ^= 0x80; // INT8 -> UINT8: add 128 + } + dst_weights[dst_base + (r - src_row_start)] = src_val; + } + // Zero-pad remaining bytes in the last block if rows % block_size != 0 + for (auto r = src_row_end - src_row_start; r < quant_block_size; ++r) { + dst_weights[dst_base + r] = signed_quant ? 0x80 : 0; + } + } ); - } else { - TransposeColumnWiseQuantizedPackAligned( - src_weights, src_scales, src_zero_points, - dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool + + // Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks] + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col) { + auto src_idx = static_cast(col); + auto dst_idx = static_cast(col) * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + } ); + + // Transpose zero points: src [k_blocks, columns] -> dst [columns, k_blocks] + // For 8-bit, zero points are byte-aligned (1 byte each), no packing needed. + if (src_zero_points && dst_zero_points) { + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col) { + auto src_idx = static_cast(col); + auto dst_idx = static_cast(col) * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + auto zp = src_zero_points[src_idx]; + if constexpr (signed_quant) { + zp ^= 0x80; // INT8 -> UINT8 + } + dst_zero_points[dst_idx] = zp; + } + } + ); + } + } else { + // Sub-byte types (2-bit, 4-bit): use packing-aware transpose paths. + // Must avoid multiple thread write to a single byte, which means the starting index + // of a thread block must be even. To achieve that, we need to customize the thread + // block size based on the parity of columns. + if (columns & 1) { + TransposeColumnWiseQuantizedPackUnaligned( + src_weights, src_scales, src_zero_points, + dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } else { + TransposeColumnWiseQuantizedPackAligned( + src_weights, src_scales, src_zero_points, + dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } } } @@ -2184,3 +2257,93 @@ MlasQDQTransposeBlockwiseQuantized( int quant_block_size, MLAS_THREADPOOL* thread_pool ); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index dddf80252f727..1ca38530003d1 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -8,12 +8,41 @@ #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" +#include "core/graph/graph_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas_q4.h" namespace onnxruntime { namespace QDQ { +namespace { +// Derive MatMulNBits 'bits' attribute from the DQ weight element type. +int64_t DQWeightBits(int32_t dt_weight) { + using TensorProto = ONNX_NAMESPACE::TensorProto; + switch (dt_weight) { + case TensorProto::INT2: + case TensorProto::UINT2: + return 2; + case TensorProto::INT4: + case TensorProto::UINT4: + return 4; + case TensorProto::INT8: + case TensorProto::UINT8: + return 8; + default: + ORT_THROW("Unsupported DQ weight type for MatMulNBits fusion: ", dt_weight); + } +} + +// Whether the DQ weight type is signed (requires zero-point offset conversion). +bool IsDQWeightSigned(int32_t dt_weight) { + using TensorProto = ONNX_NAMESPACE::TensorProto; + return dt_weight == TensorProto::INT2 || + dt_weight == TensorProto::INT4 || + dt_weight == TensorProto::INT8; +} +} // namespace + namespace { using NTO = NodesToOptimize; @@ -306,8 +335,8 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); - // currently only 4bits is supported. In the future, derive bits from DQ's weight type. - utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), extra_attributes); + int32_t dt_weight = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + utils::SetNodeAttribute(utils::MakeAttribute("bits", DQWeightBits(dt_weight)), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); return extra_attributes; @@ -339,8 +368,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, auto K = weight_arg->Shape()->dim(0).dim_value(); auto N = weight_arg->Shape()->dim(1).dim_value(); auto block_size = attrs.at("block_size").i(); + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + auto bits = DQWeightBits(dt_weight); auto quant_num = (K + block_size - 1) / block_size; - auto blob_bytes = (block_size + 1) / 2; + auto blob_bytes = (block_size * bits + 7) / 8; // Unfortunately iterating the source data is complicated, the data maybe in // external file, a raw buffer, or a repeated field depending on the data @@ -364,7 +395,7 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, cpu_allocator); std::string zp_dst_name; std::optional zp_dst; - auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); + auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); // packed zp bytes per column if (zp_tensor_proto) { zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); @@ -372,7 +403,8 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); - } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + } else if (!IsDQWeightSigned(dt_weight)) { + // Unsigned quant types without explicit zero points need a default zero-point buffer of 0. zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); zp_dst = Tensor(uint8_type, TensorShape{zp_size}, @@ -380,85 +412,309 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); } - if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - MlasQDQTransposeBlockwiseQuantized( + // Helper lambda to dispatch the MLAS transpose for a given scale type. + auto transpose = [&](auto* scale_data, auto* scale_dst_data) { + using ScaleType = std::remove_pointer_t; + bool is_signed = IsDQWeightSigned(dt_weight); + auto call = [&]() { + MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), - scale_src.data(), + scale_data, zp_src ? zp_src->DataAsByteSpan().data() : nullptr, weight_dst.MutableData(), - scale_dst.MutableData(), + scale_dst_data, zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), static_cast(block_size), intra_op_thread_pool_); + }; + + // Dispatch based on bits and signedness. Template parameters must be compile-time constants. + if (bits == 2) { + if (is_signed) { + call.template operator()<2, true>(); + } else { + call.template operator()<2, false>(); + } + } else if (bits == 4) { + if (is_signed) { + call.template operator()<4, true>(); + } else { + call.template operator()<4, false>(); + } } else { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst.MutableData(), - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); + if (is_signed) { + call.template operator()<8, true>(); + } else { + call.template operator()<8, false>(); + } } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + transpose(scale_src.data(), scale_dst.MutableData()); } else { - if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst.MutableData(), - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); + transpose(scale_src.data(), scale_dst.MutableData()); + } - } else { - MlasQDQTransposeBlockwiseQuantized( + auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); + std::optional zp_T_tp; + + if (zp_dst) { + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + } + + auto& input_defs = replacement_node.MutableInputDefs(); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_T_tp, std::move(weight_dst))); + replacement_node.MutableInputArgsCount().push_back(1); + + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_T_tp, std::move(scale_dst))); + replacement_node.MutableInputArgsCount().push_back(1); + + if (zp_T_tp) { + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_T_tp.value(), std::move(*zp_dst))); + replacement_node.MutableInputArgsCount().push_back(1); + } + + return Status::OK(); +} + +DQCastMatMulToMatMulNBitsAction::DQCastMatMulToMatMulNBitsAction( + int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) + : accuracy_level_{accuracy_level}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); +} + +Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { + // Selected nodes layout (from DQCastMatMulToMatMulNBitsSelector): + // Input(0) = DQ node + // Input(1) = Cast on input B (between DQ and MatMul) + // Target() = MatMul node + auto* dq_node = selected_nodes.Input(0); + auto* cast_b_node = selected_nodes.Input(1); + auto& matmul_node = selected_nodes.Target(); + + // --- Get DQ weight/scale/zp info --- + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + const auto& dq_attrs = dq_node->GetAttributes(); + + const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), + "Missing required weight: ", weight_arg->Name()); + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), + "Missing required scale: ", scale_arg->Name()); + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + if (zp_arg) { + graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); + } + + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = dq_attrs.at("block_size").i(); + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + auto bits = DQWeightBits(dt_weight); + auto quant_num = (K + block_size - 1) / block_size; + auto blob_bytes = (block_size * bits + 7) / 8; + + // --- Transpose weights/scales/zp via MLAS --- + Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); + + auto cpu_allocator = CPUAllocator::DefaultInstance(); + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + auto weight_dst = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + + auto orig_scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); + auto scale_dst = Tensor(orig_scale_type, TensorShape{scale_size}, cpu_allocator); + + std::string zp_dst_name; + std::optional zp_dst; + std::optional zp_src; + auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); + + if (zp_tensor_proto) { + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + } else if (!IsDQWeightSigned(dt_weight)) { + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_Cast_MatMul_zero_point_T"); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); + } + + // MLAS transpose dispatch + auto transpose = [&](auto* scale_data, auto* scale_dst_data) { + using ScaleType = std::remove_pointer_t; + bool is_signed = IsDQWeightSigned(dt_weight); + auto call = [&]() { + MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), - scale_src.data(), + scale_data, zp_src ? zp_src->DataAsByteSpan().data() : nullptr, weight_dst.MutableData(), - scale_dst.MutableData(), + scale_dst_data, zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), static_cast(block_size), intra_op_thread_pool_); + }; + + if (bits == 2) { + if (is_signed) { + call.template operator()<2, true>(); + } else { + call.template operator()<2, false>(); + } + } else if (bits == 4) { + if (is_signed) { + call.template operator()<4, true>(); + } else { + call.template operator()<4, false>(); + } + } else { + if (is_signed) { + call.template operator()<8, true>(); + } else { + call.template operator()<8, false>(); + } } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + transpose(scale_src.data(), scale_dst.MutableData()); + } else { + transpose(scale_src.data(), scale_dst.MutableData()); } + // MatMulNBits operates in the DQ scale dtype. + // Always insert Cast on input A (to DQ dtype) and Cast on output (DQ dtype to MatMul output dtype). + // ORT's redundant cast elimination optimizer will clean up unnecessary casts later. + + // Determine DQ output element type (e.g., fp16) + int32_t dq_output_dtype = cast_b_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + // Determine MatMul output element type (e.g., fp32) + int32_t matmul_output_dtype = matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + // Prepare tensor protos for initializers auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); std::optional zp_T_tp; - if (zp_dst) { zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } - auto& input_defs = replacement_node.MutableInputDefs(); + // --- Create fp16 NodeArg for MatMulNBits input A --- + NodeArg* matmul_input_a = matmul_node.MutableInputDefs()[0]; + ONNX_NAMESPACE::TypeProto input_a_fp16_type; + input_a_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype); + if (matmul_input_a->Shape()) { + *input_a_fp16_type.mutable_tensor_type()->mutable_shape() = + matmul_input_a->TypeAsProto()->tensor_type().shape(); + } + auto cast_a_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_input_a_cast"); + NodeArg* input_a_arg = &graph.GetOrCreateNodeArg(cast_a_out_name, &input_a_fp16_type); + + // --- Create fp16 NodeArg for MatMulNBits output --- + ONNX_NAMESPACE::TypeProto output_fp16_type; + output_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype); + if (matmul_node.OutputDefs()[0]->Shape()) { + *output_fp16_type.mutable_tensor_type()->mutable_shape() = + matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().shape(); + } + auto mnb_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_matmulnbits_out"); + NodeArg* mnb_output_arg = &graph.GetOrCreateNodeArg(mnb_out_name, &output_fp16_type); + + // --- Create MatMulNBits node --- + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", N), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), attrs); + + auto& new_node = graph.AddNode( + graph.GenerateNodeName(matmul_node.Name() + "_MatMulNBits"), + "MatMulNBits", + "Fused DQ+Cast+MatMul to MatMulNBits", + {input_a_arg}, + {mnb_output_arg}, + &attrs, + kMSDomain); + + const auto& target_provider = matmul_node.GetExecutionProviderType(); + new_node.SetExecutionProviderType(target_provider.empty() ? kCpuExecutionProvider : target_provider); + + // Add transposed weight, scale, zp to inputs + auto& input_defs = new_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_T_tp, std::move(weight_dst))); - replacement_node.MutableInputArgsCount().push_back(1); + new_node.MutableInputArgsCount().push_back(1); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_T_tp, std::move(scale_dst))); - replacement_node.MutableInputArgsCount().push_back(1); + new_node.MutableInputArgsCount().push_back(1); if (zp_T_tp) { input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_T_tp.value(), std::move(*zp_dst))); - replacement_node.MutableInputArgsCount().push_back(1); + new_node.MutableInputArgsCount().push_back(1); } + // --- Insert Cast on input A: matmul_input_dtype -> dq_output_dtype --- + { + NodeAttributes cast_attrs; + utils::SetNodeAttribute( + utils::MakeAttribute("to", static_cast(dq_output_dtype)), + cast_attrs); + auto& cast_node = graph.AddNode( + graph.GenerateNodeName(matmul_node.Name() + "_Cast_input_a"), + "Cast", "", + {matmul_input_a}, + {input_a_arg}, + &cast_attrs, + kOnnxDomain); + cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType()); + } + + // --- Insert Cast on output: dq_output_dtype -> matmul_output_dtype --- + { + NodeAttributes cast_attrs; + utils::SetNodeAttribute( + utils::MakeAttribute("to", static_cast(matmul_output_dtype)), + cast_attrs); + auto& cast_node = graph.AddNode( + graph.GenerateNodeName(matmul_node.Name() + "_Cast_output"), + "Cast", "", + {mnb_output_arg}, + {const_cast(matmul_node.OutputDefs()[0])}, + &cast_attrs, + kOnnxDomain); + cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType()); + } + + // --- Remove original nodes --- + auto remove_node = [&graph](Node* node) { + if (node) { + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node->Index()); + } + }; + + remove_node(&matmul_node); + remove_node(cast_b_node); + remove_node(dq_node); + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 02a8353707599..e112959cc58da 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -107,6 +107,20 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { concurrency::ThreadPool* intra_op_thread_pool_; }; +// Used together with DQCastMatMulToMatMulNBitsSelector. +// Handles DQ -> Cast(fp16->fp32) -> MatMul fusion to MatMulNBits, +// including optional Cast on input A and output type alignment. +struct DQCastMatMulToMatMulNBitsAction : public Action { + DQCastMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool); + + Status Run(Graph&, const NodesToOptimize& selected_nodes) const override; + + private: + int64_t accuracy_level_; + concurrency::ThreadPool* intra_op_thread_pool_; +}; + struct GemmReplaceWithQuant : public Action { GemmReplaceWithQuant(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d454df3393f2b..af45b63ac1fd7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -316,6 +316,25 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi #else qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); #endif + + // DQ -> Cast(fp16->fp32) -> MatMul pattern. + // Handles FP16 models where Cast nodes are inserted between DQ and MatMul. + const std::string cast_action_name{"DQCastMatMulToMatMulNBits"}; + + std::unique_ptr cast_action = + std::make_unique(qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); + +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr cast_selector = + std::make_unique(providers); + qdq_selector_action_registry.RegisterSelectorAndAction(cast_action_name, + {{"MatMul", {}}}, + std::move(cast_selector), + std::move(cast_action)); +#else + qdq_selector_action_registry.RegisterAction(cast_action_name, std::move(cast_action)); +#endif } void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 05b337d9933fb..fd6b525334046 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -20,11 +20,27 @@ constexpr bool Is16BitIntType(int32_t data_type) { (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16); } +constexpr bool Is2BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT2) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT2); +} + constexpr bool Is4BitIntType(int32_t data_type) { return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4) || (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4); } +constexpr bool Is8BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8); +} + +// Returns true if the data type is a sub-byte or byte quantized integer type +// suitable for MatMulNBits fusion (2, 4, or 8 bit). +constexpr bool IsNBitsIntType(int32_t data_type) { + return Is2BitIntType(data_type) || Is4BitIntType(data_type) || Is8BitIntType(data_type); +} + // adjust for an optional input/output that has an entry but does not exist int NumActualValues(const Node& node, bool input) { const auto& defs = input ? node.InputDefs() : node.OutputDefs(); @@ -577,7 +593,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod return false; } - if (!Is4BitIntType(dt_weight)) { + if (!IsNBitsIntType(dt_weight)) { return false; } @@ -627,6 +643,133 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod return true; } +std::optional +DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { + // Check EP compatibility + const std::string_view node_ep = node.GetExecutionProviderType(); + if (!compatible_providers_.empty() && + std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) { + return std::nullopt; + } + + const auto& graph = graph_viewer.GetGraph(); + + // node must be MatMul + if (node.OpType() != "MatMul") { + return std::nullopt; + } + + if (node.InputDefs().size() < 2) { + return std::nullopt; + } + + // Check input B: must be Cast(fp16->fp32) + const Node* cast_b = graph_viewer.GetProducerNode(node.InputDefs()[1]->Name()); + if (!cast_b || cast_b->OpType() != "Cast") { + return std::nullopt; + } + + const auto& cast_b_attrs = cast_b->GetAttributes(); + auto to_iter = cast_b_attrs.find("to"); + if (to_iter == cast_b_attrs.end() || + to_iter->second.i() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) { + return std::nullopt; + } + + // Cast B input must be fp16 + if (!cast_b->InputDefs()[0]->TypeAsProto() || + cast_b->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + return std::nullopt; + } + + // Cast B must have exactly 1 output edge (to MatMul) and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *cast_b, 1)) { + return std::nullopt; + } + + // Cast B's input must come from a DQ node + const Node* dq_node = graph_viewer.GetProducerNode(cast_b->InputDefs()[0]->Name()); + if (!dq_node || dq_node->OpType() != QDQ::DQOpName) { + return std::nullopt; + } + + // DQ must have exactly 1 output edge (to Cast B) and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *dq_node, 1)) { + return std::nullopt; + } + + // Validate DQ the same way as DQMatMulNodeGroupSelector::Check: + // DQ weight type must be 2/4/8-bit int + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zero_point_arg = dq_node->InputDefs().size() == 3 ? dq_node->InputDefs()[2] : nullptr; + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); + + // DQ output type is fp16 (validated by Cast B input check above) + // DQ scales must be float or float16 + if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && + dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + return std::nullopt; + } + + if (!IsNBitsIntType(dt_weight)) { + return std::nullopt; + } + + // DQ is blockwise quantized along axis 0 + const auto& dq_attrs = dq_node->GetAttributes(); + if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return std::nullopt; + } + + const auto a_iter = dq_attrs.find("block_size"); + if (a_iter == dq_attrs.end()) { + return std::nullopt; + } + + auto block_size = a_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return std::nullopt; + } + + // weight, scale and zero points must be constants + const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); + const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); + const auto* zp_tensor_proto = zero_point_arg ? graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr; + + if (!weight_tensor_proto || !scale_tensor_proto) { + return std::nullopt; + } + + if (zero_point_arg && !zp_tensor_proto) { + return std::nullopt; + } + + // weight, scale and zero points must have rank 2 + if (weight_tensor_proto->dims_size() != 2 || scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return std::nullopt; + } + + // check weight, scale and zero points shapes + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return std::nullopt; + } + + // Build selection + NodesToOptimizeIndicesBuilder builder; + builder.input_nodes.push_back(dq_node->Index()); + builder.input_nodes.push_back(cast_b->Index()); + builder.target_node = node.Index(); + + return builder.Build(); +} + bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 79c374b301442..5c10668733785 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -461,6 +461,27 @@ class DQMatMulToMatMulNBitsSelector : public BaseSelector { : BaseSelector(std::make_unique(), compatible_providers) {} }; +// Convert "DQ -> Cast(fp16->fp32) -> MatMul" to "MatMulNBits". +// Handles Cast(fp16->fp32) between DQ and MatMul on input B, and optionally on input A. +// Selection layout: +// input_nodes[0] = DQ node +// input_nodes[1] = Cast on input B (between DQ and MatMul) +// target_node = MatMul +// output_nodes = {} +class DQCastMatMulToMatMulNBitsSelector : public NodeSelector { + public: + explicit DQCastMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) + : compatible_providers_(compatible_providers.begin(), compatible_providers.end()) {} + + DQCastMatMulToMatMulNBitsSelector(DQCastMatMulToMatMulNBitsSelector&& rhs) noexcept + : compatible_providers_(std::move(rhs.compatible_providers_)) {} + + std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override; + + private: + std::vector compatible_providers_; +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index a0b44bbce62f8..c95881b5f3c86 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -4,6 +4,7 @@ #include #include "core/common/span_utils.h" +#include "core/common/float16.h" #include "core/framework/int4.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -343,11 +344,7 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { - // DQ contrib op schema is not updated to support blocked quantization - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + // int8/uint8 are now converted (8-bit support added), so only 16-bit and 32-bit remain as type mismatches RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); @@ -499,6 +496,103 @@ TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); } +// 8-bit DQ -> MatMul conversion to MatMulNBits(bits=8) +// Input1 +// | DQ(int8/uint8) +// \ / +// MatMul +// | DQ(int8/uint8) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulConverted_8bit(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 0.01f, 0.05f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 0.01f, 0.05f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, + static_cast(0), static_cast(2)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, + static_cast(0), static_cast(2)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-4 /*per_sample_tolerance*/, + 1e-4 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_8bit) { + // 8-bit int8/uint8 DQ weights should be fused to MatMulNBits(bits=8) + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + // block_size=32 + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 32, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 32, 0); +} + TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); @@ -511,6 +605,103 @@ TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_Cuda) { RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); } +// Cast-aware DQ->MatMul fusion tests +// Pattern: DQ(int4->fp16) -> Cast(fp16->fp32) -> MatMul(fp32) +// The Cast between DQ and MatMul on input B should be handled by the +// DQCastMatMulToMatMulNBits selector-action pair. +// MatMulNBits always operates in the DQ scale dtype (fp16). +// The action always inserts Cast on input A and Cast on output. +// ORT's redundant cast elimination optimizer cleans up unnecessary casts. +// +// Input1(fp32) DQ(int4->fp16) +// | | +// \ Cast(fp16->fp32) +// \ / +// MatMul(fp32) +// | +// output(fp32) +// +// After optimization: +// Input1(fp32) -> Cast(fp32->fp16) -> MatMulNBits(fp16) -> Cast(fp16->fp32) -> output(fp32) +template +typename std::enable_if || std::is_same_v, void>::type +RunDQCastMatMulConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + // DQ with fp16 scales + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* scale_arg = builder.MakeInitializer(scale_shape, + MLFloat16(0.01f), MLFloat16(0.05f)); + auto* dq_output = builder.MakeIntermediate(); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + // Cast fp16 -> fp32 + auto* cast_output = builder.MakeIntermediate(); + NodeAttributes cast_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("to", + static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)), + cast_attrs); + builder.AddNode("Cast", {dq_output}, {cast_output}, "", &cast_attrs); + + // MatMul + builder.AddNode("MatMul", {input_arg, cast_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + // B-side Cast removed. New Cast(fp32->fp16) on A and Cast(fp16->fp32) on output. + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-2 /*per_sample_tolerance*/, + 1e-2 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQCastMatMulConvertedToMatMulNBits) { + // DQ(int4->fp16) -> Cast(fp16->fp32) -> MatMul should be fused to MatMulNBits + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test From bf7d86173bae2c1afcbd941f4b6d3890c59cdbab Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 22:32:25 +0000 Subject: [PATCH 4/7] comments and 2bit tests --- onnxruntime/core/mlas/lib/q4_dq.cpp | 100 +++++++++++++++++- .../qdq_selector_action_transformer.cc | 2 +- .../selectors_actions/qdq_selectors.cc | 2 +- .../qdq_matmulnbits_transformer_test.cc | 92 ++++++++++++++++ 4 files changed, 193 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index f9019009ae644..5082f5079406a 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -883,8 +883,106 @@ struct BlockwiseQDQQuantizer { } ); } + } else if constexpr (qbits == 2) { + // 2-bit: 4 elements per byte. Element-by-element transpose. + constexpr int32_t kPackSize = 4; + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto packed_src_cols = (columns + kPackSize - 1) / kPackSize; + auto dst_bytes_per_quant_blk = (quant_block_size + kPackSize - 1) / kPackSize; + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + + // Transpose weights: src [rows, ceil(columns/4)] -> dst [columns, k_blocks, ceil(block_size/4)] + // Each thread handles one (row_block, column) pair writing to non-overlapping dst ranges. + MlasTryBatchParallel( + thread_pool, static_cast(row_quant_blk_num * columns), + [&](ptrdiff_t thread_blk_idx) { + auto row_blk = static_cast(thread_blk_idx / columns); + auto col = static_cast(thread_blk_idx % columns); + + auto src_row_start = row_blk * quant_block_size; + auto src_row_end = std::min(src_row_start + quant_block_size, rows); + + auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk; + + // Zero destination bytes for this block + for (int32_t b = 0; b < dst_bytes_per_quant_blk; ++b) { + dst_weights[dst_base + b] = 0; + } + + for (auto r = src_row_start; r < src_row_end; ++r) { + // Extract 2-bit value from source + auto src_byte_idx = r * packed_src_cols + col / kPackSize; + auto src_bit_shift = (col % kPackSize) * 2; + uint8_t val = (src_weights[src_byte_idx] >> src_bit_shift) & 0x3; + + if constexpr (signed_quant) { + val ^= 0x2; // int2[-2,1] -> uint2[0,3] + } + + // Place in destination + auto r_in_blk = r - src_row_start; + auto dst_byte_off = r_in_blk / kPackSize; + auto dst_bit_shift = (r_in_blk % kPackSize) * 2; + dst_weights[dst_base + dst_byte_off] |= (val << dst_bit_shift); + } + + // Zero-pad remaining positions (unsigned equivalent of 0) + if constexpr (signed_quant) { + for (auto r_in_blk = src_row_end - src_row_start; + r_in_blk < quant_block_size; ++r_in_blk) { + auto dst_byte_off = r_in_blk / kPackSize; + auto dst_bit_shift = (r_in_blk % kPackSize) * 2; + dst_weights[dst_base + dst_byte_off] |= (0x2 << dst_bit_shift); + } + } + } + ); + + // Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks] + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col) { + auto src_idx = static_cast(col); + auto dst_idx = static_cast(col) * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + } + ); + + // Transpose zero points: src [k_blocks, ceil(columns/4)] -> dst [columns, ceil(k_blocks/4)] + if (src_zero_points && dst_zero_points) { + auto packed_src_zp_cols = (columns + kPackSize - 1) / kPackSize; + auto zp_dst_bytes_per_col = (row_quant_blk_num + kPackSize - 1) / kPackSize; + + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col_idx) { + auto col = static_cast(col_idx); + auto dst_base = col * zp_dst_bytes_per_col; + + for (int32_t b = 0; b < zp_dst_bytes_per_col; ++b) { + dst_zero_points[dst_base + b] = 0; + } + + for (int32_t blk = 0; blk < row_quant_blk_num; ++blk) { + auto src_byte_idx = blk * packed_src_zp_cols + col / kPackSize; + auto src_bit_shift = (col % kPackSize) * 2; + uint8_t val = (src_zero_points[src_byte_idx] >> src_bit_shift) & 0x3; + + if constexpr (signed_quant) { + val ^= 0x2; + } + + auto dst_byte_off = blk / kPackSize; + auto dst_bit_shift = (blk % kPackSize) * 2; + dst_zero_points[dst_base + dst_byte_off] |= (val << dst_bit_shift); + } + } + ); + } } else { - // Sub-byte types (2-bit, 4-bit): use packing-aware transpose paths. + // 4-bit sub-byte types: use packing-aware transpose paths. // Must avoid multiple thread write to a single byte, which means the starting index // of a thread block must be even. To achieve that, we need to customize the thread // block size based on the parity of columns. diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index af45b63ac1fd7..0b04445692c9b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -297,7 +297,7 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi int64_t qdq_matmulnbits_accuracy_level, concurrency::ThreadPool* intra_op_thread_pool) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. - // DQ's weight is int4/uint4. DQ's scale is float/float16. + // DQ's weight is 2/4/8-bit int (int2/uint2, int4/uint4, int8/uint8). DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. const std::string action_name{"DQMatMulToMatMulNBits"}; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index fd6b525334046..485c006da4c0f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -582,7 +582,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod return false; } - // DQ weight/zero points types are int4/uint4, scales/output types are float or float16 + // DQ weight/zero points types are 2/4/8-bit int, scales/output types are float or float16 const auto* weight_arg = dq_nodes[0]->InputDefs()[0]; const auto* scale_arg = dq_nodes[0]->InputDefs()[1]; const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr; diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index c95881b5f3c86..c0cd40ad95ad4 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -582,6 +582,98 @@ void RunDQMatMulConverted_8bit(const std::vector& input1_shape, add_session_options_fn); } +// 2-bit DQ -> MatMul conversion to MatMulNBits(bits=2) +// Input1 +// | DQ(int2/uint2) +// \ / +// MatMul +// | DQ(int2/uint2) +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulConverted_2bit(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, + T(T::min_val, T::min_val, T::min_val, T::min_val), + T(T::max_val, T::max_val, T::max_val, T::max_val)); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, + T(T::min_val, T::min_val, T::min_val, T::min_val), + T(T::max_val, T::max_val, T::max_val, T::max_val)); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 8.0f, 12.0f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0, 0, 0), T(1, 1, 1, 1)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0, 0, 0), T(1, 1, 1, 1)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 25 /*opset_version*/, + 1e-4 /*per_sample_tolerance*/, + 1e-4 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_2bit) { + // 2-bit int2/uint2 DQ weights should be fused to MatMulNBits(bits=2) + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); +} + TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_8bit) { // 8-bit int8/uint8 DQ weights should be fused to MatMulNBits(bits=8) RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); From 56b5bbc332190d879cbcf2f8c13109b4411eff37 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 22:36:31 +0000 Subject: [PATCH 5/7] replace C++20 explicit lambda template parameter syntax --- .../selectors_actions/qdq_actions.cc | 66 ++++++++----------- 1 file changed, 26 insertions(+), 40 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 1ca38530003d1..15b6b45a5bcbe 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -416,39 +416,32 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, auto transpose = [&](auto* scale_data, auto* scale_dst_data) { using ScaleType = std::remove_pointer_t; bool is_signed = IsDQWeightSigned(dt_weight); - auto call = [&]() { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_data, - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst_data, - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); - }; + const uint8_t* src_w = weight_src.DataAsByteSpan().data(); + const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + uint8_t* dst_w = weight_dst.MutableData(); + uint8_t* dst_zp = zp_dst ? zp_dst->MutableData() : nullptr; + int K_int = static_cast(K); + int N_int = static_cast(N); + int bs_int = static_cast(block_size); // Dispatch based on bits and signedness. Template parameters must be compile-time constants. if (bits == 2) { if (is_signed) { - call.template operator()<2, true>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } else { - call.template operator()<2, false>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } } else if (bits == 4) { if (is_signed) { - call.template operator()<4, true>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } else { - call.template operator()<4, false>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } } else { if (is_signed) { - call.template operator()<8, true>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } else { - call.template operator()<8, false>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } } }; @@ -558,38 +551,31 @@ Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& auto transpose = [&](auto* scale_data, auto* scale_dst_data) { using ScaleType = std::remove_pointer_t; bool is_signed = IsDQWeightSigned(dt_weight); - auto call = [&]() { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_data, - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst_data, - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); - }; + const uint8_t* src_w = weight_src.DataAsByteSpan().data(); + const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + uint8_t* dst_w = weight_dst.MutableData(); + uint8_t* dst_zp = zp_dst ? zp_dst->MutableData() : nullptr; + int K_int = static_cast(K); + int N_int = static_cast(N); + int bs_int = static_cast(block_size); if (bits == 2) { if (is_signed) { - call.template operator()<2, true>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } else { - call.template operator()<2, false>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } } else if (bits == 4) { if (is_signed) { - call.template operator()<4, true>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } else { - call.template operator()<4, false>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } } else { if (is_signed) { - call.template operator()<8, true>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } else { - call.template operator()<8, false>(); + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); } } }; From c53889b522d8c5f80b7852bbe4c8a92cb2855eb8 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 23:04:56 +0000 Subject: [PATCH 6/7] deduplicate code --- .../selectors_actions/qdq_actions.cc | 362 +++++++----------- 1 file changed, 142 insertions(+), 220 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 15b6b45a5bcbe..da2e8fc37382a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -41,6 +41,126 @@ bool IsDQWeightSigned(int32_t dt_weight) { dt_weight == TensorProto::INT4 || dt_weight == TensorProto::INT8; } + +// Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits. +// Used by both DQMatMulToMatMulNBitsAction and DQCastMatMulToMatMulNBitsAction. +struct TransposedQuantizedTensors { + Tensor weight; + Tensor scale; + std::optional zero_point; + + ONNX_NAMESPACE::TensorProto weight_proto; + ONNX_NAMESPACE::TensorProto scale_proto; + std::optional zero_point_proto; +}; + +// Transpose DQ weight/scale/zp tensors from column-wise layout to MatMulNBits layout via MLAS. +// default_zp_name_prefix: prefix for auto-generated zero-point name when unsigned type has no explicit zp. +Status TransposeDQWeightsForMatMulNBits( + Graph& graph, + const Node& dq_node, + const std::string& default_zp_name_prefix, + concurrency::ThreadPool* intra_op_thread_pool, + TransposedQuantizedTensors& result) { + const auto* weight_arg = dq_node.InputDefs()[0]; + const auto* scale_arg = dq_node.InputDefs()[1]; + const auto* zp_arg = dq_node.InputDefs().size() > 2 ? dq_node.InputDefs()[2] : nullptr; + const auto& attrs = dq_node.GetAttributes(); + + const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), + "Missing required weight: ", weight_arg->Name(), " for node: ", dq_node.Name()); + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), + "Missing required scale: ", scale_arg->Name(), " for node: ", dq_node.Name()); + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + if (zp_arg) { + graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); + } + + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = attrs.at("block_size").i(); + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + auto bits = DQWeightBits(dt_weight); + auto quant_num = (K + block_size - 1) / block_size; + auto blob_bytes = (block_size * bits + 7) / 8; + + Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + + std::optional zp_src; + auto cpu_allocator = CPUAllocator::DefaultInstance(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + result.weight = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); + result.scale = Tensor(scale_type, TensorShape{scale_size}, cpu_allocator); + + std::string zp_dst_name; + auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); + + if (zp_tensor_proto) { + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + } else if (!IsDQWeightSigned(dt_weight)) { + zp_dst_name = graph.GenerateNodeArgName(default_zp_name_prefix + "_zero_point_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + memset(result.zero_point->MutableDataRaw(), 0, result.zero_point->SizeInBytes()); + } + + // Dispatch MLAS transpose based on scale type, bits, and signedness. + auto transpose = [&](auto* scale_data, auto* scale_dst_data) { + using ScaleType = std::remove_pointer_t; + bool is_signed = IsDQWeightSigned(dt_weight); + const uint8_t* src_w = weight_src.DataAsByteSpan().data(); + const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + uint8_t* dst_w = result.weight.MutableData(); + uint8_t* dst_zp = result.zero_point ? result.zero_point->MutableData() : nullptr; + int K_int = static_cast(K); + int N_int = static_cast(N); + int bs_int = static_cast(block_size); + + if (bits == 2) { + if (is_signed) { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } else { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } + } else if (bits == 4) { + if (is_signed) { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } else { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } + } else { + if (is_signed) { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } else { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } + } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + transpose(scale_src.data(), result.scale.MutableData()); + } else { + transpose(scale_src.data(), result.scale.MutableData()); + } + + result.weight_proto = utils::TensorToTensorProto(result.weight, weight_dst_name, true); + result.scale_proto = utils::TensorToTensorProto(result.scale, scale_dst_name, true); + if (result.zero_point) { + result.zero_point_proto.emplace(utils::TensorToTensorProto(*result.zero_point, zp_dst_name, true)); + } + + return Status::OK(); +} } // namespace namespace { @@ -346,129 +466,20 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { const auto* dq_node = selected_nodes.Input(0); - const auto* weight_arg = dq_node->InputDefs()[0]; - const auto* scale_arg = dq_node->InputDefs()[1]; - const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; - const auto& attrs = dq_node->GetAttributes(); - const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; - ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), - "Missing required weight: ", weight_arg->Name(), " for node: ", dq_node->Name()); - - const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; - ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), - "Missing required scale: ", scale_arg->Name(), " for node: ", dq_node->Name()); - const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; - if (zp_arg) { - // zero point is optional, one can have a NodeArg for a missing optional - // if the name is an empty string, and the below would not return ptr to a proto. - graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); - } - - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = attrs.at("block_size").i(); - int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); - auto bits = DQWeightBits(dt_weight); - auto quant_num = (K + block_size - 1) / block_size; - auto blob_bytes = (block_size * bits + 7) / 8; - - // Unfortunately iterating the source data is complicated, the data maybe in - // external file, a raw buffer, or a repeated field depending on the data - // type. UnpackTensor() already contains some of these logic and is closest - // to what we need. But it does not handle external data. - Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); - Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); - auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); - auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); - - std::optional zp_src; - auto cpu_allocator = CPUAllocator::DefaultInstance(); - auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); - auto weight_dst = Tensor(uint8_type, - TensorShape{N, quant_num, blob_bytes}, - cpu_allocator); - auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); - auto scale_size = (TensorShape{N, quant_num}).Size(); - auto scale_dst = Tensor(scale_type, - TensorShape{scale_size}, - cpu_allocator); - std::string zp_dst_name; - std::optional zp_dst; - auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); // packed zp bytes per column - - if (zp_tensor_proto) { - zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); - zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); - zp_dst = Tensor(uint8_type, - TensorShape{zp_size}, - cpu_allocator); - } else if (!IsDQWeightSigned(dt_weight)) { - // Unsigned quant types without explicit zero points need a default zero-point buffer of 0. - zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); - zp_dst = Tensor(uint8_type, - TensorShape{zp_size}, - cpu_allocator); - memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); - } - - // Helper lambda to dispatch the MLAS transpose for a given scale type. - auto transpose = [&](auto* scale_data, auto* scale_dst_data) { - using ScaleType = std::remove_pointer_t; - bool is_signed = IsDQWeightSigned(dt_weight); - const uint8_t* src_w = weight_src.DataAsByteSpan().data(); - const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; - uint8_t* dst_w = weight_dst.MutableData(); - uint8_t* dst_zp = zp_dst ? zp_dst->MutableData() : nullptr; - int K_int = static_cast(K); - int N_int = static_cast(N); - int bs_int = static_cast(block_size); - - // Dispatch based on bits and signedness. Template parameters must be compile-time constants. - if (bits == 2) { - if (is_signed) { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } - } else if (bits == 4) { - if (is_signed) { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } - } else { - if (is_signed) { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } - } - }; - - if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - transpose(scale_src.data(), scale_dst.MutableData()); - } else { - transpose(scale_src.data(), scale_dst.MutableData()); - } - - auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); - auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); - std::optional zp_T_tp; - - if (zp_dst) { - zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); - } + TransposedQuantizedTensors transposed; + ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( + graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, transposed)); auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_T_tp, std::move(weight_dst))); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); replacement_node.MutableInputArgsCount().push_back(1); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_T_tp, std::move(scale_dst))); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); replacement_node.MutableInputArgsCount().push_back(1); - if (zp_T_tp) { - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_T_tp.value(), std::move(*zp_dst))); + if (transposed.zero_point_proto) { + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point))); replacement_node.MutableInputArgsCount().push_back(1); } @@ -492,99 +503,10 @@ Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& auto* cast_b_node = selected_nodes.Input(1); auto& matmul_node = selected_nodes.Target(); - // --- Get DQ weight/scale/zp info --- - const auto* weight_arg = dq_node->InputDefs()[0]; - const auto* scale_arg = dq_node->InputDefs()[1]; - const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; - const auto& dq_attrs = dq_node->GetAttributes(); - - const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; - ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), - "Missing required weight: ", weight_arg->Name()); - const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; - ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), - "Missing required scale: ", scale_arg->Name()); - const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; - if (zp_arg) { - graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); - } - - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = dq_attrs.at("block_size").i(); - int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); - auto bits = DQWeightBits(dt_weight); - auto quant_num = (K + block_size - 1) / block_size; - auto blob_bytes = (block_size * bits + 7) / 8; - - // --- Transpose weights/scales/zp via MLAS --- - Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); - Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); - - auto cpu_allocator = CPUAllocator::DefaultInstance(); - auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); - - auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); - auto weight_dst = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); - - auto orig_scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); - auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); - auto scale_size = (TensorShape{N, quant_num}).Size(); - auto scale_dst = Tensor(orig_scale_type, TensorShape{scale_size}, cpu_allocator); - - std::string zp_dst_name; - std::optional zp_dst; - std::optional zp_src; - auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); - - if (zp_tensor_proto) { - zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); - zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); - zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); - } else if (!IsDQWeightSigned(dt_weight)) { - zp_dst_name = graph.GenerateNodeArgName("fused_DQ_Cast_MatMul_zero_point_T"); - zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); - memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); - } - - // MLAS transpose dispatch - auto transpose = [&](auto* scale_data, auto* scale_dst_data) { - using ScaleType = std::remove_pointer_t; - bool is_signed = IsDQWeightSigned(dt_weight); - const uint8_t* src_w = weight_src.DataAsByteSpan().data(); - const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; - uint8_t* dst_w = weight_dst.MutableData(); - uint8_t* dst_zp = zp_dst ? zp_dst->MutableData() : nullptr; - int K_int = static_cast(K); - int N_int = static_cast(N); - int bs_int = static_cast(block_size); - - if (bits == 2) { - if (is_signed) { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } - } else if (bits == 4) { - if (is_signed) { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } - } else { - if (is_signed) { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool_); - } - } - }; - - if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - transpose(scale_src.data(), scale_dst.MutableData()); - } else { - transpose(scale_src.data(), scale_dst.MutableData()); - } + // --- Transpose DQ weights/scales/zp via shared helper --- + TransposedQuantizedTensors transposed; + ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( + graph, *dq_node, "fused_DQ_Cast_MatMul", intra_op_thread_pool_, transposed)); // MatMulNBits operates in the DQ scale dtype. // Always insert Cast on input A (to DQ dtype) and Cast on output (DQ dtype to MatMul output dtype). @@ -595,13 +517,13 @@ Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& // Determine MatMul output element type (e.g., fp32) int32_t matmul_output_dtype = matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - // Prepare tensor protos for initializers - auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); - auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); - std::optional zp_T_tp; - if (zp_dst) { - zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); - } + const auto& dq_attrs = dq_node->GetAttributes(); + const auto* weight_arg = dq_node->InputDefs()[0]; + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = dq_attrs.at("block_size").i(); + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + auto bits = DQWeightBits(dt_weight); // --- Create fp16 NodeArg for MatMulNBits input A --- NodeArg* matmul_input_a = matmul_node.MutableInputDefs()[0]; @@ -646,14 +568,14 @@ Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& // Add transposed weight, scale, zp to inputs auto& input_defs = new_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_T_tp, std::move(weight_dst))); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); new_node.MutableInputArgsCount().push_back(1); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_T_tp, std::move(scale_dst))); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); new_node.MutableInputArgsCount().push_back(1); - if (zp_T_tp) { - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_T_tp.value(), std::move(*zp_dst))); + if (transposed.zero_point_proto) { + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point))); new_node.MutableInputArgsCount().push_back(1); } From d1a02f5b2b83986703a20bf2fa4a25bd6524444e Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 10 Mar 2026 23:21:03 +0000 Subject: [PATCH 7/7] more dedup --- .../selectors_actions/qdq_selectors.cc | 126 ++++++------------ 1 file changed, 38 insertions(+), 88 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 485c006da4c0f..c39dfeb082e35 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -558,36 +558,17 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } } -bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, - const Node* redundant_clip_node, const std::vector& dq_nodes, - const std::vector& q_nodes) const { - if (redundant_clip_node) { - return false; - } - - // Should not have any Q nodes - if (!q_nodes.empty()) { - return false; - } - - const auto& graph = graph_viewer.GetGraph(); - - // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output - if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { - return false; - } - - // DQ must be MatMul's the second input - if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { - return false; - } - - // DQ weight/zero points types are 2/4/8-bit int, scales/output types are float or float16 - const auto* weight_arg = dq_nodes[0]->InputDefs()[0]; - const auto* scale_arg = dq_nodes[0]->InputDefs()[1]; - const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr; +// Validate that a DQ node has the correct structure for MatMulNBits fusion: +// - weight type is 2/4/8-bit int, scale type is float or float16 +// - blockwise quantization along axis 0, block_size is power-of-2 and >= 16 +// - weight/scale/zp are constant initializers with rank 2 and consistent shapes +static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq_node) { + const auto* weight_arg = dq_node.InputDefs()[0]; + const auto* scale_arg = dq_node.InputDefs()[1]; + const auto* zero_point_arg = dq_node.InputDefs().size() == 3 ? dq_node.InputDefs()[2] : nullptr; int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { return false; @@ -598,7 +579,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod } // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 - const auto& dq_attrs = dq_nodes[0]->GetAttributes(); + const auto& dq_attrs = dq_node.GetAttributes(); if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { return false; } @@ -643,6 +624,33 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod return true; } +bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* redundant_clip_node, const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (redundant_clip_node) { + return false; + } + + // Should not have any Q nodes + if (!q_nodes.empty()) { + return false; + } + + const auto& graph = graph_viewer.GetGraph(); + + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + return false; + } + + // DQ must be MatMul's the second input + if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + return false; + } + + return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]); +} + std::optional DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { // Check EP compatibility @@ -699,65 +707,7 @@ DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const return std::nullopt; } - // Validate DQ the same way as DQMatMulNodeGroupSelector::Check: - // DQ weight type must be 2/4/8-bit int - const auto* weight_arg = dq_node->InputDefs()[0]; - const auto* scale_arg = dq_node->InputDefs()[1]; - const auto* zero_point_arg = dq_node->InputDefs().size() == 3 ? dq_node->InputDefs()[2] : nullptr; - int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); - - // DQ output type is fp16 (validated by Cast B input check above) - // DQ scales must be float or float16 - if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && - dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { - return std::nullopt; - } - - if (!IsNBitsIntType(dt_weight)) { - return std::nullopt; - } - - // DQ is blockwise quantized along axis 0 - const auto& dq_attrs = dq_node->GetAttributes(); - if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { - return std::nullopt; - } - - const auto a_iter = dq_attrs.find("block_size"); - if (a_iter == dq_attrs.end()) { - return std::nullopt; - } - - auto block_size = a_iter->second.i(); - if (block_size < 16 || ((block_size - 1) & block_size)) { - return std::nullopt; - } - - // weight, scale and zero points must be constants - const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); - const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); - const auto* zp_tensor_proto = zero_point_arg ? graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr; - - if (!weight_tensor_proto || !scale_tensor_proto) { - return std::nullopt; - } - - if (zero_point_arg && !zp_tensor_proto) { - return std::nullopt; - } - - // weight, scale and zero points must have rank 2 - if (weight_tensor_proto->dims_size() != 2 || scale_tensor_proto->dims_size() != 2 || - (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { - return std::nullopt; - } - - // check weight, scale and zero points shapes - if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || - weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || - (zp_tensor_proto && (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || - zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + if (!ValidateBlockwiseDQForMatMulNBits(graph, *dq_node)) { return std::nullopt; }