diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index da2e8fc37382a..fdc0818e8437b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -43,7 +43,7 @@ bool IsDQWeightSigned(int32_t dt_weight) { } // Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits. -// Used by both DQMatMulToMatMulNBitsAction and DQCastMatMulToMatMulNBitsAction. +// Used by DQMatMulToMatMulNBitsAction. struct TransposedQuantizedTensors { Tensor weight; Tensor scale; @@ -486,146 +486,6 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, return Status::OK(); } -DQCastMatMulToMatMulNBitsAction::DQCastMatMulToMatMulNBitsAction( - int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) - : accuracy_level_{accuracy_level}, - intra_op_thread_pool_{intra_op_thread_pool} { - ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); -} - -Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { - // Selected nodes layout (from DQCastMatMulToMatMulNBitsSelector): - // Input(0) = DQ node - // Input(1) = Cast on input B (between DQ and MatMul) - // Target() = MatMul node - auto* dq_node = selected_nodes.Input(0); - auto* cast_b_node = selected_nodes.Input(1); - auto& matmul_node = selected_nodes.Target(); - - // --- Transpose DQ weights/scales/zp via shared helper --- - TransposedQuantizedTensors transposed; - ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( - graph, *dq_node, "fused_DQ_Cast_MatMul", intra_op_thread_pool_, transposed)); - - // MatMulNBits operates in the DQ scale dtype. - // Always insert Cast on input A (to DQ dtype) and Cast on output (DQ dtype to MatMul output dtype). - // ORT's redundant cast elimination optimizer will clean up unnecessary casts later. - - // Determine DQ output element type (e.g., fp16) - int32_t dq_output_dtype = cast_b_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - // Determine MatMul output element type (e.g., fp32) - int32_t matmul_output_dtype = matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - - const auto& dq_attrs = dq_node->GetAttributes(); - const auto* weight_arg = dq_node->InputDefs()[0]; - auto K = weight_arg->Shape()->dim(0).dim_value(); - auto N = weight_arg->Shape()->dim(1).dim_value(); - auto block_size = dq_attrs.at("block_size").i(); - int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); - auto bits = DQWeightBits(dt_weight); - - // --- Create fp16 NodeArg for MatMulNBits input A --- - NodeArg* matmul_input_a = matmul_node.MutableInputDefs()[0]; - ONNX_NAMESPACE::TypeProto input_a_fp16_type; - input_a_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype); - if (matmul_input_a->Shape()) { - *input_a_fp16_type.mutable_tensor_type()->mutable_shape() = - matmul_input_a->TypeAsProto()->tensor_type().shape(); - } - auto cast_a_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_input_a_cast"); - NodeArg* input_a_arg = &graph.GetOrCreateNodeArg(cast_a_out_name, &input_a_fp16_type); - - // --- Create fp16 NodeArg for MatMulNBits output --- - ONNX_NAMESPACE::TypeProto output_fp16_type; - output_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype); - if (matmul_node.OutputDefs()[0]->Shape()) { - *output_fp16_type.mutable_tensor_type()->mutable_shape() = - matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().shape(); - } - auto mnb_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_matmulnbits_out"); - NodeArg* mnb_output_arg = &graph.GetOrCreateNodeArg(mnb_out_name, &output_fp16_type); - - // --- Create MatMulNBits node --- - NodeAttributes attrs; - utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); - utils::SetNodeAttribute(utils::MakeAttribute("N", N), attrs); - utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); - utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); - utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), attrs); - - auto& new_node = graph.AddNode( - graph.GenerateNodeName(matmul_node.Name() + "_MatMulNBits"), - "MatMulNBits", - "Fused DQ+Cast+MatMul to MatMulNBits", - {input_a_arg}, - {mnb_output_arg}, - &attrs, - kMSDomain); - - const auto& target_provider = matmul_node.GetExecutionProviderType(); - new_node.SetExecutionProviderType(target_provider.empty() ? kCpuExecutionProvider : target_provider); - - // Add transposed weight, scale, zp to inputs - auto& input_defs = new_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); - new_node.MutableInputArgsCount().push_back(1); - - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); - new_node.MutableInputArgsCount().push_back(1); - - if (transposed.zero_point_proto) { - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point))); - new_node.MutableInputArgsCount().push_back(1); - } - - // --- Insert Cast on input A: matmul_input_dtype -> dq_output_dtype --- - { - NodeAttributes cast_attrs; - utils::SetNodeAttribute( - utils::MakeAttribute("to", static_cast(dq_output_dtype)), - cast_attrs); - auto& cast_node = graph.AddNode( - graph.GenerateNodeName(matmul_node.Name() + "_Cast_input_a"), - "Cast", "", - {matmul_input_a}, - {input_a_arg}, - &cast_attrs, - kOnnxDomain); - cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType()); - } - - // --- Insert Cast on output: dq_output_dtype -> matmul_output_dtype --- - { - NodeAttributes cast_attrs; - utils::SetNodeAttribute( - utils::MakeAttribute("to", static_cast(matmul_output_dtype)), - cast_attrs); - auto& cast_node = graph.AddNode( - graph.GenerateNodeName(matmul_node.Name() + "_Cast_output"), - "Cast", "", - {mnb_output_arg}, - {const_cast(matmul_node.OutputDefs()[0])}, - &cast_attrs, - kOnnxDomain); - cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType()); - } - - // --- Remove original nodes --- - auto remove_node = [&graph](Node* node) { - if (node) { - graph_utils::RemoveNodeOutputEdges(graph, *node); - graph.RemoveNode(node->Index()); - } - }; - - remove_node(&matmul_node); - remove_node(cast_b_node); - remove_node(dq_node); - - return Status::OK(); -} - static std::vector GetGemmMoveInfo(bool does_q_node_exist) { NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0}; NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1}; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index e112959cc58da..02a8353707599 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -107,20 +107,6 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { concurrency::ThreadPool* intra_op_thread_pool_; }; -// Used together with DQCastMatMulToMatMulNBitsSelector. -// Handles DQ -> Cast(fp16->fp32) -> MatMul fusion to MatMulNBits, -// including optional Cast on input A and output type alignment. -struct DQCastMatMulToMatMulNBitsAction : public Action { - DQCastMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); - - Status Run(Graph&, const NodesToOptimize& selected_nodes) const override; - - private: - int64_t accuracy_level_; - concurrency::ThreadPool* intra_op_thread_pool_; -}; - struct GemmReplaceWithQuant : public Action { GemmReplaceWithQuant(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 0b04445692c9b..8cab6911646f2 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -7,6 +7,7 @@ #include #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" + #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" @@ -306,7 +307,12 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi intra_op_thread_pool); #if !defined(ORT_MINIMAL_BUILD) - std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider}; + // Include "" (empty string) to match nodes not yet assigned to an EP. + // For FP16 models on CPU EP, FP16 MatMul nodes are not claimed during partitioning + // (no FP16 MatMul kernel on CPU), leaving their EP unassigned. The DQ->MatMul fusion + // should still apply; the action assigns kCpuExecutionProvider to the resulting + // MatMulNBits node (which has both float and float16 CPU kernels). + std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""}; std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"MatMul", {}}}, @@ -316,25 +322,6 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi #else qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); #endif - - // DQ -> Cast(fp16->fp32) -> MatMul pattern. - // Handles FP16 models where Cast nodes are inserted between DQ and MatMul. - const std::string cast_action_name{"DQCastMatMulToMatMulNBits"}; - - std::unique_ptr cast_action = - std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); - -#if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr cast_selector = - std::make_unique(providers); - qdq_selector_action_registry.RegisterSelectorAndAction(cast_action_name, - {{"MatMul", {}}}, - std::move(cast_selector), - std::move(cast_action)); -#else - qdq_selector_action_registry.RegisterAction(cast_action_name, std::move(cast_action)); -#endif } void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { @@ -416,7 +403,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( apply_context, // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. - {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} { + // Also accept nodes with empty EP (unassigned) so that individual selectors + // that include "" in their compatible providers can match unassigned nodes. + {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider, ""}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index c39dfeb082e35..8a00fe11ff3fd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -651,75 +651,6 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]); } -std::optional -DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { - // Check EP compatibility - const std::string_view node_ep = node.GetExecutionProviderType(); - if (!compatible_providers_.empty() && - std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) { - return std::nullopt; - } - - const auto& graph = graph_viewer.GetGraph(); - - // node must be MatMul - if (node.OpType() != "MatMul") { - return std::nullopt; - } - - if (node.InputDefs().size() < 2) { - return std::nullopt; - } - - // Check input B: must be Cast(fp16->fp32) - const Node* cast_b = graph_viewer.GetProducerNode(node.InputDefs()[1]->Name()); - if (!cast_b || cast_b->OpType() != "Cast") { - return std::nullopt; - } - - const auto& cast_b_attrs = cast_b->GetAttributes(); - auto to_iter = cast_b_attrs.find("to"); - if (to_iter == cast_b_attrs.end() || - to_iter->second.i() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) { - return std::nullopt; - } - - // Cast B input must be fp16 - if (!cast_b->InputDefs()[0]->TypeAsProto() || - cast_b->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != - ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { - return std::nullopt; - } - - // Cast B must have exactly 1 output edge (to MatMul) and not be a graph output - if (!optimizer_utils::CheckOutputEdges(graph, *cast_b, 1)) { - return std::nullopt; - } - - // Cast B's input must come from a DQ node - const Node* dq_node = graph_viewer.GetProducerNode(cast_b->InputDefs()[0]->Name()); - if (!dq_node || dq_node->OpType() != QDQ::DQOpName) { - return std::nullopt; - } - - // DQ must have exactly 1 output edge (to Cast B) and not be a graph output - if (!optimizer_utils::CheckOutputEdges(graph, *dq_node, 1)) { - return std::nullopt; - } - - if (!ValidateBlockwiseDQForMatMulNBits(graph, *dq_node)) { - return std::nullopt; - } - - // Build selection - NodesToOptimizeIndicesBuilder builder; - builder.input_nodes.push_back(dq_node->Index()); - builder.input_nodes.push_back(cast_b->Index()); - builder.target_node = node.Index(); - - return builder.Build(); -} - bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 5c10668733785..79c374b301442 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -461,27 +461,6 @@ class DQMatMulToMatMulNBitsSelector : public BaseSelector { : BaseSelector(std::make_unique(), compatible_providers) {} }; -// Convert "DQ -> Cast(fp16->fp32) -> MatMul" to "MatMulNBits". -// Handles Cast(fp16->fp32) between DQ and MatMul on input B, and optionally on input A. -// Selection layout: -// input_nodes[0] = DQ node -// input_nodes[1] = Cast on input B (between DQ and MatMul) -// target_node = MatMul -// output_nodes = {} -class DQCastMatMulToMatMulNBitsSelector : public NodeSelector { - public: - explicit DQCastMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) - : compatible_providers_(compatible_providers.begin(), compatible_providers.end()) {} - - DQCastMatMulToMatMulNBitsSelector(DQCastMatMulToMatMulNBitsSelector&& rhs) noexcept - : compatible_providers_(std::move(rhs.compatible_providers_)) {} - - std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override; - - private: - std::vector compatible_providers_; -}; - // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index c0cd40ad95ad4..5d7eda39be271 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -697,33 +697,29 @@ TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits_Cuda) { RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1, DefaultCudaExecutionProvider()); } -// Cast-aware DQ->MatMul fusion tests -// Pattern: DQ(int4->fp16) -> Cast(fp16->fp32) -> MatMul(fp32) -// The Cast between DQ and MatMul on input B should be handled by the -// DQCastMatMulToMatMulNBits selector-action pair. -// MatMulNBits always operates in the DQ scale dtype (fp16). -// The action always inserts Cast on input A and Cast on output. -// ORT's redundant cast elimination optimizer cleans up unnecessary casts. +// DQ(fp16) -> MatMul fusion test +// Pattern: DQ(int4, fp16_scale) -> MatMul(fp16) +// For FP16 models on CPU EP, CPU EP doesn't claim FP16 MatMul during partitioning +// (no FP16 MatMul kernel on CPU), so the node's EP is empty "". +// The DQ->MatMul fusion should still match and fuse to MatMulNBits. // -// Input1(fp32) DQ(int4->fp16) -// | | -// \ Cast(fp16->fp32) -// \ / -// MatMul(fp32) +// Input1(fp16) DQ(int4->fp16) +// \ / +// MatMul(fp16) // | -// output(fp32) +// output(fp16) // // After optimization: -// Input1(fp32) -> Cast(fp32->fp16) -> MatMulNBits(fp16) -> Cast(fp16->fp32) -> output(fp32) +// Input1(fp16) -> MatMulNBits(fp16) -> output(fp16) template typename std::enable_if || std::is_same_v, void>::type -RunDQCastMatMulConverted(const std::vector& input1_shape, +RunDQMatMulFP16Converted(const std::vector& input1_shape, const std::vector& weight_shape, const int64_t axis, const int64_t block_size, int64_t accuracy_level) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input_arg = builder.MakeInput(input1_shape, -1.0f, 1.0f); + auto* input_arg = builder.MakeInput(input1_shape, MLFloat16(-1.0f), MLFloat16(1.0f)); auto* output_arg = builder.MakeOutput(); // DQ with fp16 scales @@ -745,24 +741,14 @@ RunDQCastMatMulConverted(const std::vector& input1_shape, builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); } - // Cast fp16 -> fp32 - auto* cast_output = builder.MakeIntermediate(); - NodeAttributes cast_attrs; - utils::SetNodeAttribute(utils::MakeAttribute("to", - static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)), - cast_attrs); - builder.AddNode("Cast", {dq_output}, {cast_output}, "", &cast_attrs); - - // MatMul - builder.AddNode("MatMul", {input_arg, cast_output}, {output_arg}); + // MatMul (fp16) + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); }; auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(false); EXPECT_EQ(op_to_count["MatMul"], 0); - // B-side Cast removed. New Cast(fp32->fp16) on A and Cast(fp16->fp32) on output. - EXPECT_EQ(op_to_count["Cast"], 2); EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 1); EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); }; @@ -786,12 +772,12 @@ RunDQCastMatMulConverted(const std::vector& input1_shape, add_session_options_fn); } -TEST(QDQTransformerTests, DQCastMatMulConvertedToMatMulNBits) { - // DQ(int4->fp16) -> Cast(fp16->fp32) -> MatMul should be fused to MatMulNBits - RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); - RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); - RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); - RunDQCastMatMulConverted({12, 32}, {32, 16}, 0, 16, 0); +TEST(QDQTransformerTests, DQMatMulFP16ConvertedToMatMulNBits) { + // DQ(int4, fp16_scale) -> MatMul(fp16) should be fused to MatMulNBits + RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); + RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); + RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); + RunDQMatMulFP16Converted({12, 32}, {32, 16}, 0, 16, 0); } #endif // !defined(DISABLE_CONTRIB_OPS)