diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
index 8caa67f266266..4efaec325292a 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
@@ -13,6 +13,39 @@ namespace onnxruntime {
 
+/**
+ * Checks whether the output path from a given node leads to a QuantizeLinear op, optionally with a
+ * non-branching ReLU or Clip op in between. See also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc.
+ *
+ * @param node The starting node to check the output path from.
+ * @param graph The graph containing the nodes.
+ *
+ * @return true if such a path exists, false otherwise.
+ */
+static bool IsNoBranchPathToQuantizeLinear(const Node& node, const Graph& graph) {
+  const Node* current = &node;
+  while (true) {
+    // Conv / ConvTranspose / Gemm produce a single output.
+    if (current->OutputDefs().size() != 1) {
+      return false;
+    }
+    const std::vector<const Node*>& consumers = graph.GetConsumerNodes(current->OutputDefs()[0]->Name());
+    // Branching or no consumer: not eligible.
+    if (consumers.size() != 1) {
+      return false;
+    }
+    const Node* consumer = consumers[0];
+    if (consumer->OpType() == QDQ::QOpName) {
+      return true;
+    }
+    // Allow ReLU or Clip, see also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc.
+    if (consumer->OpType() != "Relu" && consumer->OpType() != "Clip") {
+      return false;
+    }
+    current = consumer;
+  }
+}
+
 Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                          const logging::Logger& logger) const {
   const GraphViewer graph_viewer{graph};
@@ -43,11 +76,8 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph
       continue;
     }
 
-    // Require that the node's output is consumed by a single QuantizeLinear node.
-    // Otherwise, if only the inputs are quantized, but not the output, then this node group would not
-    // be considered a QDQ node unit anyway.
-    std::vector<const Node*> children_nodes = graph.GetConsumerNodes(node.OutputDefs()[0]->Name());
-    if (children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName) {
+    // Check if the output path leads to a QuantizeLinear node, optionally with a ReLU or Clip op in between.
+    if (!IsNoBranchPathToQuantizeLinear(node, graph)) {
       continue;
     }
 
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
index a7d993bc54642..83c226115aa84 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
@@ -238,7 +238,7 @@ class BaseOpBuilder : public IOpBuilder {
   }
 
   // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end]
-  void ReArranagePads(std::vector<uint32_t>& pads) const {
+  void ReArrangePads(std::vector<uint32_t>& pads) const {
     auto pads_size = pads.size();
     auto middle_pos = pads_size / 2;
     std::vector<uint32_t> first_half(pads.begin(), pads.begin() + middle_pos);
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc
index 7391fbffccd8e..541ca5ca7ab14 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc
@@ -24,7 +24,6 @@ static Status GetOnnxConvType(const std::string& onnx_op_type, OnnxConvType& con
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unsupported ONNX convolution op type: ", onnx_op_type.c_str());
   }
-
   return Status::OK();
 }
 
@@ -171,7 +170,7 @@ Status ConvOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
     return ProcessConv2D3DInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation);
   }
 
-  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D(rank 5), 2D (rank 4) or 1D (rank 3) inputs.");
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D (rank 5), 2D (rank 4) or 1D (rank 3) inputs.");
 }
 
 Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper,
@@ -712,7 +711,7 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
     }
   }
 
-  ReArranagePads(pads);
+  ReArrangePads(pads);
   uint32_t pad_size = narrow<uint32_t>(pads.size() / 2);
   QnnParamWrapper pad_amount_paramwrapper(node_unit.Index(), node_unit.Name(), QNN_OP_CONV_2D_PARAM_PAD_AMOUNT,
                                           {pad_size, 2}, std::move(pads));
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc
index 404d3c402c21e..d2b1434c1c896 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc
@@ -193,7 +193,7 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap
                  [](int64_t item) { return SafeInt<uint32_t>(item); });
   // Onnx format is begin_0, begin_1, ..., end_0, end_1, ...
   // Qnn format is begin_0, end_0, begin_1, end_1, ...
-  ReArranagePads(pad_amount);
+  ReArrangePads(pad_amount);
 
   std::vector<uint32_t> input_shape;
   ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0.");
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
index 7970f5a12c1bf..851c65aa1c1a3 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
@@ -195,7 +195,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper,
       }
     }
   }
-  ReArranagePads(pad_amount);
+  ReArrangePads(pad_amount);
 
   // Param: rounding_mode.
   rounding_mode = node_helper.Get("ceil_mode", rounding_mode);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
index 22a5d5d43df47..1e4ba6afe6f0b 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
@@ -158,7 +158,7 @@ bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_na
       return false;
     }
 
-    // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor
+    // During graph partitioning, we only need to do op validation, it's not required to create Qnn graph tensor
     // We only need to create the Qnn graph tensor during Compile to create Qnn graph
     if (!do_op_validation) {
       std::string error_string;
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 1baa6e529cbde..4196ed280a993 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -5444,8 +5444,59 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) {
 #endif
 }
 
+// Tests that the WeightBiasQuantization optimizer still processes nodes whose output reaches QuantizeLinear through a
+// type-preserving, non-branching ReLU op, e.g., Q -> DQ -> Conv (w/ float weight initializer) -> ReLU -> Q -> DQ.
+TEST(QDQTransformerTests, WeightBiasQuantization_ConvWithReLU) {
+  auto test_case = [](bool use_contrib_qdq) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      NodeArg* input_fp32 = builder.MakeInput<float>({1, 1, 4, 4}, -1.0f, 1.0f);
+      NodeArg* weight_fp32 = builder.MakeInitializer<float>({2, 1, 3, 3}, -1.0f, 1.0f);
+      NodeArg* input_q = builder.MakeIntermediate();
+      NodeArg* input_dq = builder.MakeIntermediate();
+      NodeArg* conv_fp32 = builder.MakeIntermediate();
+      NodeArg* relu_fp32 = builder.MakeIntermediate();
+      NodeArg* relu_q = builder.MakeIntermediate();
+      NodeArg* relu_dq = builder.MakeOutput();
+      builder.AddQuantizeLinearNode(input_fp32, 0.18f, static_cast<uint8_t>(127), input_q, use_contrib_qdq);
+      builder.AddDequantizeLinearNode(input_q, 0.18f, static_cast<uint8_t>(127), input_dq, use_contrib_qdq);
+      auto& conv_node = builder.AddNode("Conv", {input_dq, weight_fp32}, {conv_fp32});
+      conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
+      conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
+      conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
+      builder.AddNode("Relu", {conv_fp32}, {relu_fp32});
+      builder.AddQuantizeLinearNode(relu_fp32, 0.69f, static_cast<uint8_t>(127), relu_q, use_contrib_qdq);
+      builder.AddDequantizeLinearNode(relu_q, 0.69f, static_cast<uint8_t>(127), relu_dq, use_contrib_qdq);
+    };
+
+    // Conv's weight should be quantized and folded; one additional Q/DQ pair is inserted for the weight.
+    auto check_transformed_graph = [](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      EXPECT_EQ(op_to_count["QuantizeLinear"] + op_to_count["com.microsoft.QuantizeLinear"], 2 + 1);
+      EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2 + 1);
+      EXPECT_EQ(op_to_count["Conv"], 1);
+      EXPECT_EQ(op_to_count["Relu"], 1);
+    };
+
+    TransformerTester(build_test_case,
+                      check_transformed_graph,
+                      TransformerLevel::Default,
+                      TransformerLevel::Level1,
+                      /*opset_version=*/20,
+                      /*per_sample_tolerance=*/0.01,
+                      /*relative_per_sample_tolerance=*/0.01,
+                      /*transformer=*/std::make_unique<WeightBiasQuantization>());
+  };
+
+  test_case(false);
+#if !defined(DISABLE_CONTRIB_OPS)
+  test_case(true);
+#endif
+}
+
 // Tests that the WeightBiasQuantization optimizer does not process nodes that do not
-// already have an output that is consumed by a single QuantizeLinear node.
+// already have an output with a valid (non-branching) path to a QuantizeLinear node.
 TEST(QDQTransformerTests, WeightBiasQuantization_SkipIfOutputNotQuantized) {
   auto test_case = [](bool add_final_reshape) {
     auto build_test_case = [&](ModelTestBuilder& builder) {