onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
@@ -43,6 +43,14 @@
      continue;
    }

    // Require that the node's output is consumed by a single QuantizeLinear node.
    // Otherwise, if only the inputs are quantized but not the output, this node group
    // would not be considered a QDQ node unit anyway.
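    // Illustrative sketch of the two cases (not part of this change):
    //
    //   DQ -> Conv(float weight/bias) -> Q -> ...            eligible: full QDQ node unit
    //   DQ -> Conv(float weight/bias) -> non-Q consumer or
    //                                    graph output         skipped by this check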
    std::vector<const Node*> children_nodes = graph.GetConsumerNodes(node.OutputDefs()[0]->Name());

[cpplint] reported by reviewdog: onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc:49: Add #include <vector> for vector<> [build/include_what_you_use] [4]
    if (children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName) {
      continue;
    }

    Node& dq_0 = *graph.GetNode(parent_node_0->Index());
    Node* dq_1 = nullptr;
    const ONNX_NAMESPACE::TensorProto* weight_proto = nullptr;
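One way to resolve the cpplint finding above (a sketch, not part of this diff; exact placement should follow the file's existing include order) is to add the missing header alongside the other includes at the top of weight_bias_quantization.cc:

#include <vector>  // for the std::vector<const Node*> returned by graph.GetConsumerNodes()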
56 changes: 56 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -5349,6 +5349,62 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) {
#endif
}

// Tests that the WeightBiasQuantization optimizer skips nodes whose output is not
// already consumed by a single QuantizeLinear node.
TEST(QDQTransformerTests, WeightBiasQuantization_SkipIfOutputNotQuantized) {
  auto test_case = [](bool add_final_reshape) {
    auto build_test_case = [&](ModelTestBuilder& builder) {
      NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 67, 67}, std::numeric_limits<uint8_t>::min(),
                                                      std::numeric_limits<uint8_t>::max());
      NodeArg* weight_arg = builder.MakeInitializer<float>({24, 1, 5, 5}, -0.1f, 0.1f);
      NodeArg* bias_arg = builder.MakeInitializer<float>({24}, -0.1f, 0.1f);
      NodeArg* input_dq_arg = builder.MakeIntermediate();
      NodeArg* conv_output_arg = add_final_reshape ? builder.MakeIntermediate() : builder.MakeOutput();

      builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.014f, static_cast<uint8_t>(127), input_dq_arg);
      auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_arg, bias_arg}, {conv_output_arg});
      conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{5, 5});
      conv_node.AddAttribute("strides", std::vector<int64_t>{2, 2});
      conv_node.AddAttribute("group", static_cast<int64_t>(24));
      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});

      // Make adding a final Reshape node configurable to test two cases:
      // - Conv produces a graph output
      // - Conv output is consumed by some node that is NOT a QuantizeLinear
      // In either case, the WeightBiasQuantization optimizer should skip this node.
      if (add_final_reshape) {
        NodeArg* reshape_output_arg = builder.MakeOutput();
        NodeArg* new_shape_arg = builder.Make1DInitializer<int64_t>({1, -1});
        builder.AddNode("Reshape", {conv_output_arg, new_shape_arg}, {reshape_output_arg});
      }
    };

    auto check_graph = [add_final_reshape](InferenceSessionWrapper& session) {
      auto op_to_count = CountOpsInGraph(session.GetGraph());
      const QDQOpKeys qdq_keys = GetQDQOpKeys(false);

      // The optimizer should leave the original graph's nodes unchanged.
      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1);
      EXPECT_EQ(op_to_count["Conv"], 1);
      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
      EXPECT_EQ(op_to_count["Reshape"], static_cast<int>(add_final_reshape));
    };

    TransformerTester(build_test_case,
                      check_graph,
                      TransformerLevel::Default,
                      TransformerLevel::Level1,
                      /*opset_version*/ 21,
                      /*per_sample_tolerance*/ 0.0,
                      /*relative_per_sample_tolerance*/ 0.0,
                      std::make_unique<WeightBiasQuantization>());
  };

  test_case(false);  // Conv produces a graph output directly
  test_case(true);   // Conv -> Reshape -> graph_output
}

TEST(QDQTransformerTests, WeightBiasQuantization_ConvTranspose_Weight) {
  auto test_case = [](bool use_contrib_qdq) {
    auto build_test_case = [&](ModelTestBuilder& builder) {