diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
index a9f6420d6ac3b..effd13abc3e3a 100644
--- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
+++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
@@ -346,7 +346,7 @@ static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit,
   auto op_of_quantized_layer = node_unit.Outputs();
   for (auto& itr : op_of_quantized_layer) {
     auto it = graph_op_data_type.find(itr.node_arg.Name());
-    if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") {
+    if (it != graph_op_data_type.end() && (it->second == "tensor(uint8)" || it->second == "tensor(uint16)")) {
       return true;
     }
   }
@@ -369,6 +369,11 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
     graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data();
   }
 
+  // Check if any quantized node feeds into the src graph output
+  if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) {
+    return true;
+  }
+
   // If UInt16 Q, don't keep it
   if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) {
     reason = SkipReason::Int16QDQ;
@@ -381,9 +386,7 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
   } else if (op_type == "Add") {
     // Add keeps all Qs
     return true;
-  } else if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) {
-    return true;
-  } else {
+  } else {
     // Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list
     return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false);
   }
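
For clarity, here is a minimal, self-contained sketch of the Q-keeping decision order as it reads after this patch. The names used here (`QInfo`, `FeedsQuantizedGraphOutput`, `KeepQ`, `graph_outputs`) are stand-ins invented for illustration and are not the ONNX Runtime API; only the ordering of the checks and the `"tensor(uint8)"`/`"tensor(uint16)"` comparisons mirror the diff above.

```cpp
// Simplified model of CheckQRuleSet's post-patch control flow; the real code
// operates on NodeUnit/GraphViewer rather than these stand-in types.
#include <cassert>
#include <map>
#include <string>

enum class SkipReason { None, Int16QDQ };

struct QInfo {
  std::string op_type;         // op type of the quantized target node
  std::string q_data_type;     // "uint8", "uint16", "int16", ...
  std::string output_name;     // graph output this Q feeds, if any
  bool next_target_supported;  // stand-in for IsNextTargetNodeOfQValid(...)
};

// Stand-in for CheckQFeedsIntoQuantizedOutput: after the patch it also
// matches uint16 graph outputs, not just uint8.
bool FeedsQuantizedGraphOutput(const QInfo& q,
                               const std::map<std::string, std::string>& graph_outputs) {
  auto it = graph_outputs.find(q.output_name);
  return it != graph_outputs.end() &&
         (it->second == "tensor(uint8)" || it->second == "tensor(uint16)");
}

// Post-patch decision order: the graph-output check now runs before the
// 16-bit skip, so a Q feeding a quantized graph output is kept even if it
// is uint16/int16.
bool KeepQ(const QInfo& q,
           const std::map<std::string, std::string>& graph_outputs,
           SkipReason& reason) {
  if (FeedsQuantizedGraphOutput(q, graph_outputs)) return true;

  if (q.q_data_type == "uint16" || q.q_data_type == "int16") {
    reason = SkipReason::Int16QDQ;  // If UInt16 Q, don't keep it
    return false;
  }

  if (q.op_type == "Conv" || q.op_type == "MatMul") return true;
  if (q.op_type == "Add") return true;  // Add keeps all Qs
  // Unsupported op: keep the Q only if the next target is Conv/Add/MatMul.
  return q.next_target_supported;
}

int main() {
  SkipReason reason = SkipReason::None;
  // A uint16 Q feeding a uint16 graph output is now kept; previously the
  // 16-bit check ran first and rejected it, and the output check only
  // accepted tensor(uint8).
  QInfo q{"Mul", "uint16", "graph_out", false};
  assert(KeepQ(q, {{"graph_out", "tensor(uint16)"}}, reason));
  return 0;
}
```

The net effect of the two hunks in CheckQRuleSet is that the CheckQFeedsIntoQuantizedOutput rule moves out of the op-type else-if chain to the top of the function and is extended to tensor(uint16) graph outputs, so such Qs are no longer dropped by the Int16QDQ rule.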