diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index fbbf4005ae4a5..5082f5079406a 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -663,6 +663,9 @@ struct BlockwiseQDQQuantizer { return (val >> (idx << 1)) & 0x3; } else if constexpr (qbits == 4) { return (val >> (idx << 2)) & 0xF; + } else if constexpr (qbits == 8) { + (void)idx; + return val; } } @@ -674,6 +677,10 @@ struct BlockwiseQDQQuantizer { } else if constexpr (qbits == 4) { auto shift = idx << 2; return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } else if constexpr (qbits == 8) { + (void)idx; + (void)dst; + return val; } } @@ -813,21 +820,185 @@ struct BlockwiseQDQQuantizer { src_zero_points || signed_quant || dst_zero_points, "Unsigned quant types without zero points must allocate zero points with value 0." ); - // Must avoid multiple thread write to a single byte, which means the starting index - // of a thread block must be even. To achieve that, we need to customize the thread - // block size based on the parity of columns. - if (columns & 1) { - TransposeColumnWiseQuantizedPackUnaligned( - src_weights, src_scales, src_zero_points, - dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool + + if constexpr (qbits == 8) { + // 8-bit: each element is one byte, no sub-byte packing needed. + // Simple byte-level transpose from [rows, columns] to [columns, k_blocks, block_size]. + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto dst_bytes_per_quant_blk = quant_block_size; // 8 bits = 1 byte per element + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + + // Transpose weights: src [rows, columns] -> dst [columns, k_blocks, block_size] + MlasTryBatchParallel( + thread_pool, static_cast(row_quant_blk_num * columns), + [&](ptrdiff_t thread_blk_idx) { + auto row_blk = static_cast(thread_blk_idx / columns); + auto col = static_cast(thread_blk_idx % columns); + + auto src_row_start = row_blk * quant_block_size; + auto src_row_end = std::min(src_row_start + quant_block_size, rows); + + auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk; + for (auto r = src_row_start; r < src_row_end; ++r) { + auto src_val = src_weights[r * columns + col]; + if constexpr (signed_quant) { + src_val ^= 0x80; // INT8 -> UINT8: add 128 + } + dst_weights[dst_base + (r - src_row_start)] = src_val; + } + // Zero-pad remaining bytes in the last block if rows % block_size != 0 + for (auto r = src_row_end - src_row_start; r < quant_block_size; ++r) { + dst_weights[dst_base + r] = signed_quant ? 0x80 : 0; + } + } ); - } else { - TransposeColumnWiseQuantizedPackAligned( - src_weights, src_scales, src_zero_points, - dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool + + // Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks] + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col) { + auto src_idx = static_cast(col); + auto dst_idx = static_cast(col) * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + } ); + + // Transpose zero points: src [k_blocks, columns] -> dst [columns, k_blocks] + // For 8-bit, zero points are byte-aligned (1 byte each), no packing needed. + if (src_zero_points && dst_zero_points) { + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col) { + auto src_idx = static_cast(col); + auto dst_idx = static_cast(col) * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + auto zp = src_zero_points[src_idx]; + if constexpr (signed_quant) { + zp ^= 0x80; // INT8 -> UINT8 + } + dst_zero_points[dst_idx] = zp; + } + } + ); + } + } else if constexpr (qbits == 2) { + // 2-bit: 4 elements per byte. Element-by-element transpose. + constexpr int32_t kPackSize = 4; + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto packed_src_cols = (columns + kPackSize - 1) / kPackSize; + auto dst_bytes_per_quant_blk = (quant_block_size + kPackSize - 1) / kPackSize; + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + + // Transpose weights: src [rows, ceil(columns/4)] -> dst [columns, k_blocks, ceil(block_size/4)] + // Each thread handles one (row_block, column) pair writing to non-overlapping dst ranges. + MlasTryBatchParallel( + thread_pool, static_cast(row_quant_blk_num * columns), + [&](ptrdiff_t thread_blk_idx) { + auto row_blk = static_cast(thread_blk_idx / columns); + auto col = static_cast(thread_blk_idx % columns); + + auto src_row_start = row_blk * quant_block_size; + auto src_row_end = std::min(src_row_start + quant_block_size, rows); + + auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk; + + // Zero destination bytes for this block + for (int32_t b = 0; b < dst_bytes_per_quant_blk; ++b) { + dst_weights[dst_base + b] = 0; + } + + for (auto r = src_row_start; r < src_row_end; ++r) { + // Extract 2-bit value from source + auto src_byte_idx = r * packed_src_cols + col / kPackSize; + auto src_bit_shift = (col % kPackSize) * 2; + uint8_t val = (src_weights[src_byte_idx] >> src_bit_shift) & 0x3; + + if constexpr (signed_quant) { + val ^= 0x2; // int2[-2,1] -> uint2[0,3] + } + + // Place in destination + auto r_in_blk = r - src_row_start; + auto dst_byte_off = r_in_blk / kPackSize; + auto dst_bit_shift = (r_in_blk % kPackSize) * 2; + dst_weights[dst_base + dst_byte_off] |= (val << dst_bit_shift); + } + + // Zero-pad remaining positions (unsigned equivalent of 0) + if constexpr (signed_quant) { + for (auto r_in_blk = src_row_end - src_row_start; + r_in_blk < quant_block_size; ++r_in_blk) { + auto dst_byte_off = r_in_blk / kPackSize; + auto dst_bit_shift = (r_in_blk % kPackSize) * 2; + dst_weights[dst_base + dst_byte_off] |= (0x2 << dst_bit_shift); + } + } + } + ); + + // Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks] + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col) { + auto src_idx = static_cast(col); + auto dst_idx = static_cast(col) * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + } + ); + + // Transpose zero points: src [k_blocks, ceil(columns/4)] -> dst [columns, ceil(k_blocks/4)] + if (src_zero_points && dst_zero_points) { + auto packed_src_zp_cols = (columns + kPackSize - 1) / kPackSize; + auto zp_dst_bytes_per_col = (row_quant_blk_num + kPackSize - 1) / kPackSize; + + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t col_idx) { + auto col = static_cast(col_idx); + auto dst_base = col * zp_dst_bytes_per_col; + + for (int32_t b = 0; b < zp_dst_bytes_per_col; ++b) { + dst_zero_points[dst_base + b] = 0; + } + + for (int32_t blk = 0; blk < row_quant_blk_num; ++blk) { + auto src_byte_idx = blk * packed_src_zp_cols + col / kPackSize; + auto src_bit_shift = (col % kPackSize) * 2; + uint8_t val = (src_zero_points[src_byte_idx] >> src_bit_shift) & 0x3; + + if constexpr (signed_quant) { + val ^= 0x2; + } + + auto dst_byte_off = blk / kPackSize; + auto dst_bit_shift = (blk % kPackSize) * 2; + dst_zero_points[dst_base + dst_byte_off] |= (val << dst_bit_shift); + } + } + ); + } + } else { + // 4-bit sub-byte types: use packing-aware transpose paths. + // Must avoid multiple thread write to a single byte, which means the starting index + // of a thread block must be even. To achieve that, we need to customize the thread + // block size based on the parity of columns. + if (columns & 1) { + TransposeColumnWiseQuantizedPackUnaligned( + src_weights, src_scales, src_zero_points, + dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } else { + TransposeColumnWiseQuantizedPackAligned( + src_weights, src_scales, src_zero_points, + dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } } } @@ -2184,3 +2355,93 @@ MlasQDQTransposeBlockwiseQuantized( int quant_block_size, MLAS_THREADPOOL* thread_pool ); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index dddf80252f727..da2e8fc37382a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -8,12 +8,161 @@ #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" +#include "core/graph/graph_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas_q4.h" namespace onnxruntime { namespace QDQ { +namespace { +// Derive MatMulNBits 'bits' attribute from the DQ weight element type. +int64_t DQWeightBits(int32_t dt_weight) { + using TensorProto = ONNX_NAMESPACE::TensorProto; + switch (dt_weight) { + case TensorProto::INT2: + case TensorProto::UINT2: + return 2; + case TensorProto::INT4: + case TensorProto::UINT4: + return 4; + case TensorProto::INT8: + case TensorProto::UINT8: + return 8; + default: + ORT_THROW("Unsupported DQ weight type for MatMulNBits fusion: ", dt_weight); + } +} + +// Whether the DQ weight type is signed (requires zero-point offset conversion). +bool IsDQWeightSigned(int32_t dt_weight) { + using TensorProto = ONNX_NAMESPACE::TensorProto; + return dt_weight == TensorProto::INT2 || + dt_weight == TensorProto::INT4 || + dt_weight == TensorProto::INT8; +} + +// Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits. +// Used by both DQMatMulToMatMulNBitsAction and DQCastMatMulToMatMulNBitsAction. +struct TransposedQuantizedTensors { + Tensor weight; + Tensor scale; + std::optional zero_point; + + ONNX_NAMESPACE::TensorProto weight_proto; + ONNX_NAMESPACE::TensorProto scale_proto; + std::optional zero_point_proto; +}; + +// Transpose DQ weight/scale/zp tensors from column-wise layout to MatMulNBits layout via MLAS. +// default_zp_name_prefix: prefix for auto-generated zero-point name when unsigned type has no explicit zp. +Status TransposeDQWeightsForMatMulNBits( + Graph& graph, + const Node& dq_node, + const std::string& default_zp_name_prefix, + concurrency::ThreadPool* intra_op_thread_pool, + TransposedQuantizedTensors& result) { + const auto* weight_arg = dq_node.InputDefs()[0]; + const auto* scale_arg = dq_node.InputDefs()[1]; + const auto* zp_arg = dq_node.InputDefs().size() > 2 ? dq_node.InputDefs()[2] : nullptr; + const auto& attrs = dq_node.GetAttributes(); + + const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), + "Missing required weight: ", weight_arg->Name(), " for node: ", dq_node.Name()); + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), + "Missing required scale: ", scale_arg->Name(), " for node: ", dq_node.Name()); + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + if (zp_arg) { + graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); + } + + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = attrs.at("block_size").i(); + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + auto bits = DQWeightBits(dt_weight); + auto quant_num = (K + block_size - 1) / block_size; + auto blob_bytes = (block_size * bits + 7) / 8; + + Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + + std::optional zp_src; + auto cpu_allocator = CPUAllocator::DefaultInstance(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + result.weight = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); + result.scale = Tensor(scale_type, TensorShape{scale_size}, cpu_allocator); + + std::string zp_dst_name; + auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size(); + + if (zp_tensor_proto) { + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + } else if (!IsDQWeightSigned(dt_weight)) { + zp_dst_name = graph.GenerateNodeArgName(default_zp_name_prefix + "_zero_point_T"); + result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + memset(result.zero_point->MutableDataRaw(), 0, result.zero_point->SizeInBytes()); + } + + // Dispatch MLAS transpose based on scale type, bits, and signedness. + auto transpose = [&](auto* scale_data, auto* scale_dst_data) { + using ScaleType = std::remove_pointer_t; + bool is_signed = IsDQWeightSigned(dt_weight); + const uint8_t* src_w = weight_src.DataAsByteSpan().data(); + const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr; + uint8_t* dst_w = result.weight.MutableData(); + uint8_t* dst_zp = result.zero_point ? result.zero_point->MutableData() : nullptr; + int K_int = static_cast(K); + int N_int = static_cast(N); + int bs_int = static_cast(block_size); + + if (bits == 2) { + if (is_signed) { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } else { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } + } else if (bits == 4) { + if (is_signed) { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } else { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } + } else { + if (is_signed) { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } else { + MlasQDQTransposeBlockwiseQuantized(src_w, scale_data, src_zp, dst_w, scale_dst_data, dst_zp, true, K_int, N_int, bs_int, intra_op_thread_pool); + } + } + }; + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + transpose(scale_src.data(), result.scale.MutableData()); + } else { + transpose(scale_src.data(), result.scale.MutableData()); + } + + result.weight_proto = utils::TensorToTensorProto(result.weight, weight_dst_name, true); + result.scale_proto = utils::TensorToTensorProto(result.scale, scale_dst_name, true); + if (result.zero_point) { + result.zero_point_proto.emplace(utils::TensorToTensorProto(*result.zero_point, zp_dst_name, true)); + } + + return Status::OK(); +} +} // namespace + namespace { using NTO = NodesToOptimize; @@ -306,8 +455,8 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); - // currently only 4bits is supported. In the future, derive bits from DQ's weight type. - utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), extra_attributes); + int32_t dt_weight = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + utils::SetNodeAttribute(utils::MakeAttribute("bits", DQWeightBits(dt_weight)), extra_attributes); utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); return extra_attributes; @@ -317,147 +466,162 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { const auto* dq_node = selected_nodes.Input(0); - const auto* weight_arg = dq_node->InputDefs()[0]; - const auto* scale_arg = dq_node->InputDefs()[1]; - const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; - const auto& attrs = dq_node->GetAttributes(); - const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; - ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), - "Missing required weight: ", weight_arg->Name(), " for node: ", dq_node->Name()); - - const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; - ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), - "Missing required scale: ", scale_arg->Name(), " for node: ", dq_node->Name()); - const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; - if (zp_arg) { - // zero point is optional, one can have a NodeArg for a missing optional - // if the name is an empty string, and the below would not return ptr to a proto. - graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); - } + TransposedQuantizedTensors transposed; + ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( + graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, transposed)); - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = attrs.at("block_size").i(); - auto quant_num = (K + block_size - 1) / block_size; - auto blob_bytes = (block_size + 1) / 2; - - // Unfortunately iterating the source data is complicated, the data maybe in - // external file, a raw buffer, or a repeated field depending on the data - // type. UnpackTensor() already contains some of these logic and is closest - // to what we need. But it does not handle external data. - Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); - Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); - auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); - auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + auto& input_defs = replacement_node.MutableInputDefs(); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); + replacement_node.MutableInputArgsCount().push_back(1); - std::optional zp_src; - auto cpu_allocator = CPUAllocator::DefaultInstance(); - auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); - auto weight_dst = Tensor(uint8_type, - TensorShape{N, quant_num, blob_bytes}, - cpu_allocator); - auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); - auto scale_size = (TensorShape{N, quant_num}).Size(); - auto scale_dst = Tensor(scale_type, - TensorShape{scale_size}, - cpu_allocator); - std::string zp_dst_name; - std::optional zp_dst; - auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); + replacement_node.MutableInputArgsCount().push_back(1); - if (zp_tensor_proto) { - zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); - zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); - zp_dst = Tensor(uint8_type, - TensorShape{zp_size}, - cpu_allocator); - } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); - zp_dst = Tensor(uint8_type, - TensorShape{zp_size}, - cpu_allocator); - memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); + if (transposed.zero_point_proto) { + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point))); + replacement_node.MutableInputArgsCount().push_back(1); } - if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst.MutableData(), - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); - } else { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst.MutableData(), - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); - } - } else { - if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst.MutableData(), - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); + return Status::OK(); +} - } else { - MlasQDQTransposeBlockwiseQuantized( - weight_src.DataAsByteSpan().data(), - scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.MutableData(), - scale_dst.MutableData(), - zp_dst ? zp_dst->MutableData() : nullptr, - true, - static_cast(K), - static_cast(N), - static_cast(block_size), - intra_op_thread_pool_); - } - } +DQCastMatMulToMatMulNBitsAction::DQCastMatMulToMatMulNBitsAction( + int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) + : accuracy_level_{accuracy_level}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); +} - auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); - auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); - std::optional zp_T_tp; +Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { + // Selected nodes layout (from DQCastMatMulToMatMulNBitsSelector): + // Input(0) = DQ node + // Input(1) = Cast on input B (between DQ and MatMul) + // Target() = MatMul node + auto* dq_node = selected_nodes.Input(0); + auto* cast_b_node = selected_nodes.Input(1); + auto& matmul_node = selected_nodes.Target(); + + // --- Transpose DQ weights/scales/zp via shared helper --- + TransposedQuantizedTensors transposed; + ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( + graph, *dq_node, "fused_DQ_Cast_MatMul", intra_op_thread_pool_, transposed)); + + // MatMulNBits operates in the DQ scale dtype. + // Always insert Cast on input A (to DQ dtype) and Cast on output (DQ dtype to MatMul output dtype). + // ORT's redundant cast elimination optimizer will clean up unnecessary casts later. + + // Determine DQ output element type (e.g., fp16) + int32_t dq_output_dtype = cast_b_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + // Determine MatMul output element type (e.g., fp32) + int32_t matmul_output_dtype = matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + const auto& dq_attrs = dq_node->GetAttributes(); + const auto* weight_arg = dq_node->InputDefs()[0]; + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = dq_attrs.at("block_size").i(); + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + auto bits = DQWeightBits(dt_weight); + + // --- Create fp16 NodeArg for MatMulNBits input A --- + NodeArg* matmul_input_a = matmul_node.MutableInputDefs()[0]; + ONNX_NAMESPACE::TypeProto input_a_fp16_type; + input_a_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype); + if (matmul_input_a->Shape()) { + *input_a_fp16_type.mutable_tensor_type()->mutable_shape() = + matmul_input_a->TypeAsProto()->tensor_type().shape(); + } + auto cast_a_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_input_a_cast"); + NodeArg* input_a_arg = &graph.GetOrCreateNodeArg(cast_a_out_name, &input_a_fp16_type); + + // --- Create fp16 NodeArg for MatMulNBits output --- + ONNX_NAMESPACE::TypeProto output_fp16_type; + output_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype); + if (matmul_node.OutputDefs()[0]->Shape()) { + *output_fp16_type.mutable_tensor_type()->mutable_shape() = + matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().shape(); + } + auto mnb_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_matmulnbits_out"); + NodeArg* mnb_output_arg = &graph.GetOrCreateNodeArg(mnb_out_name, &output_fp16_type); + + // --- Create MatMulNBits node --- + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", N), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), attrs); + + auto& new_node = graph.AddNode( + graph.GenerateNodeName(matmul_node.Name() + "_MatMulNBits"), + "MatMulNBits", + "Fused DQ+Cast+MatMul to MatMulNBits", + {input_a_arg}, + {mnb_output_arg}, + &attrs, + kMSDomain); + + const auto& target_provider = matmul_node.GetExecutionProviderType(); + new_node.SetExecutionProviderType(target_provider.empty() ? kCpuExecutionProvider : target_provider); + + // Add transposed weight, scale, zp to inputs + auto& input_defs = new_node.MutableInputDefs(); + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); + new_node.MutableInputArgsCount().push_back(1); + + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); + new_node.MutableInputArgsCount().push_back(1); + + if (transposed.zero_point_proto) { + input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point))); + new_node.MutableInputArgsCount().push_back(1); + } - if (zp_dst) { - zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + // --- Insert Cast on input A: matmul_input_dtype -> dq_output_dtype --- + { + NodeAttributes cast_attrs; + utils::SetNodeAttribute( + utils::MakeAttribute("to", static_cast(dq_output_dtype)), + cast_attrs); + auto& cast_node = graph.AddNode( + graph.GenerateNodeName(matmul_node.Name() + "_Cast_input_a"), + "Cast", "", + {matmul_input_a}, + {input_a_arg}, + &cast_attrs, + kOnnxDomain); + cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType()); } - auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_T_tp, std::move(weight_dst))); - replacement_node.MutableInputArgsCount().push_back(1); + // --- Insert Cast on output: dq_output_dtype -> matmul_output_dtype --- + { + NodeAttributes cast_attrs; + utils::SetNodeAttribute( + utils::MakeAttribute("to", static_cast(matmul_output_dtype)), + cast_attrs); + auto& cast_node = graph.AddNode( + graph.GenerateNodeName(matmul_node.Name() + "_Cast_output"), + "Cast", "", + {mnb_output_arg}, + {const_cast(matmul_node.OutputDefs()[0])}, + &cast_attrs, + kOnnxDomain); + cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType()); + } - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_T_tp, std::move(scale_dst))); - replacement_node.MutableInputArgsCount().push_back(1); + // --- Remove original nodes --- + auto remove_node = [&graph](Node* node) { + if (node) { + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node->Index()); + } + }; - if (zp_T_tp) { - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_T_tp.value(), std::move(*zp_dst))); - replacement_node.MutableInputArgsCount().push_back(1); - } + remove_node(&matmul_node); + remove_node(cast_b_node); + remove_node(dq_node); return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 02a8353707599..e112959cc58da 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -107,6 +107,20 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { concurrency::ThreadPool* intra_op_thread_pool_; }; +// Used together with DQCastMatMulToMatMulNBitsSelector. +// Handles DQ -> Cast(fp16->fp32) -> MatMul fusion to MatMulNBits, +// including optional Cast on input A and output type alignment. +struct DQCastMatMulToMatMulNBitsAction : public Action { + DQCastMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool); + + Status Run(Graph&, const NodesToOptimize& selected_nodes) const override; + + private: + int64_t accuracy_level_; + concurrency::ThreadPool* intra_op_thread_pool_; +}; + struct GemmReplaceWithQuant : public Action { GemmReplaceWithQuant(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d454df3393f2b..0b04445692c9b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -297,7 +297,7 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi int64_t qdq_matmulnbits_accuracy_level, concurrency::ThreadPool* intra_op_thread_pool) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. - // DQ's weight is int4/uint4. DQ's scale is float/float16. + // DQ's weight is 2/4/8-bit int (int2/uint2, int4/uint4, int8/uint8). DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. const std::string action_name{"DQMatMulToMatMulNBits"}; @@ -316,6 +316,25 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi #else qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); #endif + + // DQ -> Cast(fp16->fp32) -> MatMul pattern. + // Handles FP16 models where Cast nodes are inserted between DQ and MatMul. + const std::string cast_action_name{"DQCastMatMulToMatMulNBits"}; + + std::unique_ptr cast_action = + std::make_unique(qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); + +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr cast_selector = + std::make_unique(providers); + qdq_selector_action_registry.RegisterSelectorAndAction(cast_action_name, + {{"MatMul", {}}}, + std::move(cast_selector), + std::move(cast_action)); +#else + qdq_selector_action_registry.RegisterAction(cast_action_name, std::move(cast_action)); +#endif } void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 05b337d9933fb..c39dfeb082e35 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -20,11 +20,27 @@ constexpr bool Is16BitIntType(int32_t data_type) { (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16); } +constexpr bool Is2BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT2) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT2); +} + constexpr bool Is4BitIntType(int32_t data_type) { return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4) || (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4); } +constexpr bool Is8BitIntType(int32_t data_type) { + return (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || + (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8); +} + +// Returns true if the data type is a sub-byte or byte quantized integer type +// suitable for MatMulNBits fusion (2, 4, or 8 bit). +constexpr bool IsNBitsIntType(int32_t data_type) { + return Is2BitIntType(data_type) || Is4BitIntType(data_type) || Is8BitIntType(data_type); +} + // adjust for an optional input/output that has an entry but does not exist int NumActualValues(const Node& node, bool input) { const auto& defs = input ? node.InputDefs() : node.OutputDefs(); @@ -542,47 +558,28 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } } -bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, - const Node* redundant_clip_node, const std::vector& dq_nodes, - const std::vector& q_nodes) const { - if (redundant_clip_node) { - return false; - } - - // Should not have any Q nodes - if (!q_nodes.empty()) { - return false; - } - - const auto& graph = graph_viewer.GetGraph(); - - // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output - if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { - return false; - } - - // DQ must be MatMul's the second input - if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { - return false; - } - - // DQ weight/zero points types are int4/uint4, scales/output types are float or float16 - const auto* weight_arg = dq_nodes[0]->InputDefs()[0]; - const auto* scale_arg = dq_nodes[0]->InputDefs()[1]; - const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr; +// Validate that a DQ node has the correct structure for MatMulNBits fusion: +// - weight type is 2/4/8-bit int, scale type is float or float16 +// - blockwise quantization along axis 0, block_size is power-of-2 and >= 16 +// - weight/scale/zp are constant initializers with rank 2 and consistent shapes +static bool ValidateBlockwiseDQForMatMulNBits(const Graph& graph, const Node& dq_node) { + const auto* weight_arg = dq_node.InputDefs()[0]; + const auto* scale_arg = dq_node.InputDefs()[1]; + const auto* zero_point_arg = dq_node.InputDefs().size() == 3 ? dq_node.InputDefs()[2] : nullptr; int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { return false; } - if (!Is4BitIntType(dt_weight)) { + if (!IsNBitsIntType(dt_weight)) { return false; } // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 - const auto& dq_attrs = dq_nodes[0]->GetAttributes(); + const auto& dq_attrs = dq_node.GetAttributes(); if (const auto a_iter = dq_attrs.find("axis"); a_iter == dq_attrs.end() || a_iter->second.i() != 0) { return false; } @@ -627,6 +624,102 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod return true; } +bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* redundant_clip_node, const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (redundant_clip_node) { + return false; + } + + // Should not have any Q nodes + if (!q_nodes.empty()) { + return false; + } + + const auto& graph = graph_viewer.GetGraph(); + + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + return false; + } + + // DQ must be MatMul's the second input + if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + return false; + } + + return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]); +} + +std::optional +DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { + // Check EP compatibility + const std::string_view node_ep = node.GetExecutionProviderType(); + if (!compatible_providers_.empty() && + std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) { + return std::nullopt; + } + + const auto& graph = graph_viewer.GetGraph(); + + // node must be MatMul + if (node.OpType() != "MatMul") { + return std::nullopt; + } + + if (node.InputDefs().size() < 2) { + return std::nullopt; + } + + // Check input B: must be Cast(fp16->fp32) + const Node* cast_b = graph_viewer.GetProducerNode(node.InputDefs()[1]->Name()); + if (!cast_b || cast_b->OpType() != "Cast") { + return std::nullopt; + } + + const auto& cast_b_attrs = cast_b->GetAttributes(); + auto to_iter = cast_b_attrs.find("to"); + if (to_iter == cast_b_attrs.end() || + to_iter->second.i() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) { + return std::nullopt; + } + + // Cast B input must be fp16 + if (!cast_b->InputDefs()[0]->TypeAsProto() || + cast_b->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + return std::nullopt; + } + + // Cast B must have exactly 1 output edge (to MatMul) and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *cast_b, 1)) { + return std::nullopt; + } + + // Cast B's input must come from a DQ node + const Node* dq_node = graph_viewer.GetProducerNode(cast_b->InputDefs()[0]->Name()); + if (!dq_node || dq_node->OpType() != QDQ::DQOpName) { + return std::nullopt; + } + + // DQ must have exactly 1 output edge (to Cast B) and not be a graph output + if (!optimizer_utils::CheckOutputEdges(graph, *dq_node, 1)) { + return std::nullopt; + } + + if (!ValidateBlockwiseDQForMatMulNBits(graph, *dq_node)) { + return std::nullopt; + } + + // Build selection + NodesToOptimizeIndicesBuilder builder; + builder.input_nodes.push_back(dq_node->Index()); + builder.input_nodes.push_back(cast_b->Index()); + builder.target_node = node.Index(); + + return builder.Build(); +} + bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 79c374b301442..5c10668733785 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -461,6 +461,27 @@ class DQMatMulToMatMulNBitsSelector : public BaseSelector { : BaseSelector(std::make_unique(), compatible_providers) {} }; +// Convert "DQ -> Cast(fp16->fp32) -> MatMul" to "MatMulNBits". +// Handles Cast(fp16->fp32) between DQ and MatMul on input B, and optionally on input A. +// Selection layout: +// input_nodes[0] = DQ node +// input_nodes[1] = Cast on input B (between DQ and MatMul) +// target_node = MatMul +// output_nodes = {} +class DQCastMatMulToMatMulNBitsSelector : public NodeSelector { + public: + explicit DQCastMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) + : compatible_providers_(compatible_providers.begin(), compatible_providers.end()) {} + + DQCastMatMulToMatMulNBitsSelector(DQCastMatMulToMatMulNBitsSelector&& rhs) noexcept + : compatible_providers_(std::move(rhs.compatible_providers_)) {} + + std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override; + + private: + std::vector compatible_providers_; +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index a0b44bbce62f8..c0cd40ad95ad4 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -4,6 +4,7 @@ #include #include "core/common/span_utils.h" +#include "core/common/float16.h" #include "core/framework/int4.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -343,11 +344,7 @@ void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input } TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { - // DQ contrib op schema is not updated to support blocked quantization - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); - RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + // int8/uint8 are now converted (8-bit support added), so only 16-bit and 32-bit remain as type mismatches RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); @@ -499,6 +496,195 @@ TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); } +// 8-bit DQ -> MatMul conversion to MatMulNBits(bits=8) +// Input1 +// | DQ(int8/uint8) +// \ / +// MatMul +// | DQ(int8/uint8) +// \ / +// MatMul +// | +// output +template +void RunDQMatMulConverted_8bit(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 0.01f, 0.05f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 0.01f, 0.05f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, + static_cast(0), static_cast(2)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, + static_cast(0), static_cast(2)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-4 /*per_sample_tolerance*/, + 1e-4 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +// 2-bit DQ -> MatMul conversion to MatMulNBits(bits=2) +// Input1 +// | DQ(int2/uint2) +// \ / +// MatMul +// | DQ(int2/uint2) +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulConverted_2bit(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, + T(T::min_val, T::min_val, T::min_val, T::min_val), + T(T::max_val, T::max_val, T::max_val, T::max_val)); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, + T(T::min_val, T::min_val, T::min_val, T::min_val), + T(T::max_val, T::max_val, T::max_val, T::max_val)); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 8.0f, 12.0f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0, 0, 0), T(1, 1, 1, 1)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0, 0, 0), T(1, 1, 1, 1)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 25 /*opset_version*/, + 1e-4 /*per_sample_tolerance*/, + 1e-4 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_2bit) { + // 2-bit int2/uint2 DQ weights should be fused to MatMulNBits(bits=2) + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_2bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_8bit) { + // 8-bit int8/uint8 DQ weights should be fused to MatMulNBits(bits=8) + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 16, 0); + // block_size=32 + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 32, 0); + RunDQMatMulConverted_8bit({12, 32}, {32, 16}, {16, 12}, 0, 32, 0); +} + TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_Cuda) { // DQ contrib op schema is not updated to support blocked quantization RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0, DefaultCudaExecutionProvider()); @@ -511,6 +697,103 @@ TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_Cuda) { RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); } +// Cast-aware DQ->MatMul fusion tests +// Pattern: DQ(int4->fp16) -> Cast(fp16->fp32) -> MatMul(fp32) +// The Cast between DQ and MatMul on input B should be handled by the +// DQCastMatMulToMatMulNBits selector-action pair. +// MatMulNBits always operates in the DQ scale dtype (fp16). +// The action always inserts Cast on input A and Cast on output. +// ORT's redundant cast elimination optimizer cleans up unnecessary casts. +// +// Input1(fp32) DQ(int4->fp16) +// | | +// \ Cast(fp16->fp32) +// \ / +// MatMul(fp32) +// | +// output(fp32) +// +// After optimization: +// Input1(fp32) -> Cast(fp32->fp16) -> MatMulNBits(fp16) -> Cast(fp16->fp32) -> output(fp32) +template +typename std::enable_if || std::is_same_v, void>::type +RunDQCastMatMulConverted(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + // DQ with fp16 scales + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* scale_arg = builder.MakeInitializer(scale_shape, + MLFloat16(0.01f), MLFloat16(0.05f)); + auto* dq_output = builder.MakeIntermediate(); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + // Cast fp16 -> fp32 + auto* cast_output = builder.MakeIntermediate(); + NodeAttributes cast_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("to", + static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)), + cast_attrs); + builder.AddNode("Cast", {dq_output}, {cast_output}, "", &cast_attrs); + + // MatMul + builder.AddNode("MatMul", {input_arg, cast_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + // B-side Cast removed. New Cast(fp32->fp16) on A and Cast(fp16->fp32) on output. + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-2 /*per_sample_tolerance*/, + 1e-2 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQCastMatMulConvertedToMatMulNBits) { + // DQ(int4->fp16) -> Cast(fp16->fp32) -> MatMul should be fused to MatMulNBits + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); + RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test