diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index dbf86e2bb7fc7..7aba3b9549f23 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -67,6 +67,7 @@
 #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
 #endif
 #include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
+#include "core/optimizer/qdq_transformer/where_dummy_dq.h"
 #include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
 #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
 #include "core/optimizer/qdq_transformer/qdq_propagation.h"
@@ -271,6 +272,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       // It runs unconditionally in InferenceSession::TransformGraph() prior to Level1 optimizers.
       // We also put it here with other Level1 optimizers so that it can fix things up after their changes.
       transformers.emplace_back(std::make_unique<WeightBiasQuantization>());
+      transformers.emplace_back(std::make_unique<WhereDummyDq>());
     }
 
     // add __backwardpass attribute to nodes after YieldOp, ROCm-only
diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc
new file mode 100644
index 0000000000000..a8b9814f1020c
--- /dev/null
+++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc
@@ -0,0 +1,169 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/qdq_transformer/where_dummy_dq.h"
+
+#include "core/framework/tensorprotoutils.h"
+#include "core/common/common.h"
+#include "core/util/qmath.h"
+#include "core/graph/graph_utils.h"
+#include "core/graph/graph_viewer.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+#include "core/optimizer/qdq_transformer/qdq_util.h"
+
+namespace onnxruntime {
+bool WhereDummyDq::SatisfyCondition(const Graph& graph, const Node& node) const {
+  if (node.OpType() != "Where") {
+    return false;
+  }
+  const auto& where_inputs = node.InputDefs();
+  const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name());
+  const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name());
+
+  bool is_p1_dq = (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName);
+  bool is_p2_dq = (parent_node_2 && parent_node_2->OpType() == QDQ::DQOpName);
+
+  // WhereDummyDq targets Where nodes with one DQ input and one scalar initializer input.
+  if (is_p1_dq && !parent_node_2) {
+    return (where_inputs[2]->Shape()->dim_size() == 0);
+  }
+  if (!parent_node_1 && is_p2_dq) {
+    return (where_inputs[1]->Shape()->dim_size() == 0);
+  }
+  return false;
+}
+
+Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const {
+  const auto& where_inputs = node.InputDefs();
+  const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name());
+  const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name());
+
+  // SatisfyCondition guarantees exactly one DQ input and one initializer input.
+  const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2;
+  int const_idx = parent_node_1 ? 2 : 1;
+
+  const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr;
+  graph.GetInitializedTensor(dq_node->InputDefs()[1]->Name(), dq_node_scale_proto);
+  const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr;
+  graph.GetInitializedTensor(dq_node->InputDefs()[2]->Name(), dq_node_zp_proto);
+
+  // Dummy data initializer.
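+  // The dummy data is fixed to 1 and typed to match the existing DQ's zero
+  // point, so the inserted DQ uses the same quantized type as the real one.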
+  ONNX_NAMESPACE::TensorProto dummy_data_proto;
+  dummy_data_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_data"));
+  // Use the same data type as the existing DQ's zero point.
+  dummy_data_proto.set_data_type(dq_node_zp_proto->data_type());
+
+  // Dummy zero-point initializer.
+  ONNX_NAMESPACE::TensorProto dummy_zp_proto;
+  dummy_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_zp"));
+  dummy_zp_proto.set_data_type(dq_node_zp_proto->data_type());
+
+  switch (dummy_zp_proto.data_type()) {
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
+      int8_t zp = 0;
+      int8_t dummy_data = 1;
+      dummy_zp_proto.set_raw_data(&zp, 1);
+      dummy_data_proto.set_raw_data(&dummy_data, 1);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
+      uint8_t zp = 0;
+      uint8_t dummy_data = 1;
+      dummy_zp_proto.set_raw_data(&zp, 1);
+      dummy_data_proto.set_raw_data(&dummy_data, 1);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
+      int16_t zp = 0;
+      int16_t dummy_data = 1;
+      dummy_zp_proto.set_raw_data(&zp, 2);
+      dummy_data_proto.set_raw_data(&dummy_data, 2);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
+      uint16_t zp = 0;
+      uint16_t dummy_data = 1;
+      dummy_zp_proto.set_raw_data(&zp, 2);
+      dummy_data_proto.set_raw_data(&dummy_data, 2);
+      break;
+    }
+    default:
+      LOGS(logger, WARNING) << "WhereDummyDq only supports DQ zero points of type INT8, UINT8, INT16 or UINT16";
+      return Status::OK();
+  }
+
+  // Set the dummy scale to the original scalar value.
+  const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr;
+  graph.GetInitializedTensor(where_inputs[const_idx]->Name(), const_node_data_proto);
+  Initializer initializer(graph, *const_node_data_proto, graph.ModelPath());
+  if (dq_node_scale_proto->data_type() != const_node_data_proto->data_type()) {
+    // WhereDummyDq copies the scalar value into the dummy DQ's scale, so the types must match.
+    LOGS(logger, WARNING) << "WhereDummyDq requires the existing DQ's scale to have the same data type as the scalar";
+    return Status::OK();
+  }
+
+  // Dummy scale initializer.
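+  // With data = 1, zero_point = 0, and scale = scalar, the dummy DQ computes
+  // (data - zero_point) * scale = (1 - 0) * scalar = scalar, so the Where node
+  // still receives exactly the original constant value.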
+  ONNX_NAMESPACE::TensorProto dummy_scale_proto;
+  dummy_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_scale"));
+  dummy_scale_proto.set_data_type(dq_node_scale_proto->data_type());
+  switch (initializer.data_type()) {
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+      float* where_const_scalar = initializer.data<float>();
+      dummy_scale_proto.set_raw_data(where_const_scalar, sizeof(float));
+      break;
+    }
+    default:
+      LOGS(logger, WARNING) << "WhereDummyDq only supports FLOAT scalar initializers";
+      return Status::OK();
+  }
+
+  // Start editing the graph.
+  NodeArg& dummy_data_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_data_proto);
+  NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_scale_proto);
+  NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_zp_proto);
+
+  ONNX_NAMESPACE::TypeProto dummy_dq_type_proto = utils::TypeProtoFromTensorProto(*const_node_data_proto);
+  dummy_dq_type_proto.mutable_tensor_type()->set_elem_type(const_node_data_proto->data_type());
+  NodeArg& dummy_dq_arg =
+      graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_dummy_dq"), &dummy_dq_type_proto);
+  Node& dummy_dq_node =
+      graph.AddNode(
+          graph.GenerateNodeArgName(node.Name() + "_dummy_dq"),
+          QDQ::DQOpName,
+          "DequantizeLinear from WhereDummyDq GraphTransformer",
+          {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg},
+          {&dummy_dq_arg},
+          nullptr,
+          dq_node->Domain());
+
+  // Capture the scalar's name before rewiring: where_inputs is a view over the
+  // node's input defs and reflects the mutation below.
+  const std::string const_input_name = where_inputs[const_idx]->Name();
+  node.MutableInputDefs()[const_idx] = &dummy_dq_arg;
+  if (graph.GetConsumerNodes(const_input_name).empty()) {
+    graph.RemoveInitializedTensor(const_input_name);
+  }
+  graph.AddEdge(dummy_dq_node.Index(), node.Index(), 0, const_idx);
+  modified = true;
+
+  return Status::OK();
+}
+
+Status WhereDummyDq::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  const GraphViewer graph_viewer{graph};
+  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
+  for (const auto node_idx : node_indices) {
+    auto* node_ptr = graph.GetNode(node_idx);
+    if (!node_ptr) {
+      continue;
+    }
+
+    Node& node = *node_ptr;
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+    if (SatisfyCondition(graph, node)) {
+      ORT_RETURN_IF_ERROR(InsertDummyDQ(node, graph, modified, logger));
+    }
+  }
+
+  return Status::OK();
+}
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h
new file mode 100644
index 0000000000000..3260a865f8c4b
--- /dev/null
+++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
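+//
+// Pattern (illustrative):
+//   Before: Where(cond, DQ(x), scalar_initializer)
+//   After:  Where(cond, DQ(x), DQ(data=1, scale=scalar, zero_point=0))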
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+  @Class WhereDummyDq
+
+  Graph transformer that inserts a dummy DQ on a Where node's scalar initializer
+  input, so that a Where node with one DQ input and one scalar initializer input
+  can form a complete QDQ node unit.
+*/
+class WhereDummyDq : public GraphTransformer {
+ public:
+  WhereDummyDq() noexcept : GraphTransformer("WhereDummyDq") {}
+
+ private:
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+
+  bool SatisfyCondition(const Graph& graph, const Node& node) const;
+  Status InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const;
+};
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 98640bb2f6b4c..1baa6e529cbde 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -12,6 +12,7 @@
 #include "core/mlas/inc/mlas.h"
 #include "core/optimizer/double_qdq_pairs_remover.h"
 #include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
+#include "core/optimizer/qdq_transformer/where_dummy_dq.h"
 #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
 #include "core/optimizer/qdq_transformer/qdq_propagation.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
@@ -3220,6 +3221,79 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {
   test_case(TransformerLevel::Level3, 0);  // Will not fuse Relu into QuantizeLinear due to zero-point != -128
 }
 
+template <typename QuantType>
+void TestWhereWithDqInput(bool is_dq_1,
+                          bool is_dq_2,
+                          int expected_num_where,
+                          int expected_num_dq,
+                          int expected_num_q,
+                          bool expected_modified) {
+  auto& logger = DefaultLoggingManager().DefaultLogger();
+  Model model("WhereDummyDqTester", false, logger);
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder builder(graph);
+
+  NodeArg* where_in1 = nullptr;
+  NodeArg* where_in2 = nullptr;
+  if (is_dq_1) {
+    // DQ
+    auto* dq_Input = builder.MakeInput<QuantType>({4, 3, 32}, 0.0, 1.0);
+    auto* dq_scale = builder.MakeInitializer<float>({}, 0.0, 1.0);
+    auto* dq_zp = builder.MakeInitializer<QuantType>({}, 0.0, 1.0);
+    where_in1 = builder.MakeIntermediate();
+    builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in1});
+  } else {
+    where_in1 = builder.MakeInitializer<float>({}, 0.0, 1.0);
+  }
+  if (is_dq_2) {
+    // DQ
+    auto* dq_Input = builder.MakeInput<QuantType>({4, 3, 32}, 0.0, 1.0);
+    auto* dq_scale = builder.MakeInitializer<float>({}, 0.0, 1.0);
+    auto* dq_zp = builder.MakeInitializer<QuantType>({}, 0.0, 1.0);
+    where_in2 = builder.MakeIntermediate();
+    builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in2});
+  } else {
+    where_in2 = builder.MakeInitializer<float>({}, 0.0, 1.0);
+  }
+
+  // Where
+  auto* where_cond = builder.MakeInputBool({4, 3, 32});
+  auto* where_out = builder.MakeIntermediate();
+  builder.AddNode("Where", {where_cond, where_in1, where_in2}, {where_out});
+
+  // Q
+  auto* q_scale = builder.MakeInitializer<float>({}, 0.0, 1.0);
+  auto* q_zp = builder.MakeInitializer<QuantType>({}, 0.0, 1.0);
+  auto* q_out = builder.MakeOutput();
+  builder.AddNode("QuantizeLinear", {where_out, q_scale, q_zp}, {q_out});
+
+  builder.SetGraphOutputs();
+  ASSERT_STATUS_OK(graph.Resolve());
+
+  auto where_optimizer = std::make_unique<WhereDummyDq>();
+  bool modified = false;
+  ASSERT_STATUS_OK(where_optimizer->Apply(graph, modified, logger));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
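+
+  // With exactly one DQ input, WhereDummyDq adds a dummy DQ on the scalar input
+  // (DequantizeLinear count 1 -> 2) and reports the graph as modified; with two
+  // DQ inputs or none, the graph is left untouched.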
+  ASSERT_EQ(op_to_count["Where"], expected_num_where);
+  ASSERT_EQ(op_to_count["DequantizeLinear"], expected_num_dq);
+  ASSERT_EQ(op_to_count["QuantizeLinear"], expected_num_q);
+  ASSERT_EQ(modified, expected_modified);
+}
+
+TEST(QDQTransformerTests, WhereDummyDqTest) {
+  TestWhereWithDqInput<uint8_t>(true, true, 1, 2, 1, false);
+  TestWhereWithDqInput<uint8_t>(true, false, 1, 2, 1, true);
+  TestWhereWithDqInput<uint8_t>(false, true, 1, 2, 1, true);
+  TestWhereWithDqInput<uint8_t>(false, false, 1, 0, 1, false);
+  TestWhereWithDqInput<int8_t>(true, true, 1, 2, 1, false);
+  TestWhereWithDqInput<int8_t>(true, false, 1, 2, 1, true);
+  TestWhereWithDqInput<int8_t>(false, true, 1, 2, 1, true);
+  TestWhereWithDqInput<int8_t>(false, false, 1, 0, 1, false);
+}
+
 TEST(QDQTransformerTests, Concat) {
   auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes,
                        int64_t axis,