diff --git a/neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_matmul_requantize.py b/neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_matmul_requantize.py
index 2060fecbc4e..9647b657d4c 100644
--- a/neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_matmul_requantize.py
+++ b/neural_compressor/adaptor/tf_utils/graph_rewriter/int8/fuse_matmul_requantize.py
@@ -588,10 +588,16 @@ def do_transformation(self):
             min_filter_node = None
             # The Min and Max of non-const weight node are from QuantizeV2's output, not valid nodes.
             # Add check here for excluding this case.
-            if ":2" not in new_node.input[6]:
-                max_filter_node = self.graph_info[new_node.input[6]].node
-            if ":1" not in new_node.input[5]:
-                min_filter_node = self.graph_info[new_node.input[5]].node
+            if len(attr_fused_ops) == 0: # single matmul case
+                if ":2" not in new_node.input[5]:
+                    max_filter_node = self.graph_info[new_node.input[5]].node
+                if ":1" not in new_node.input[4]:
+                    min_filter_node = self.graph_info[new_node.input[4]].node
+            else:
+                if ":2" not in new_node.input[6]:
+                    max_filter_node = self.graph_info[new_node.input[6]].node
+                if ":1" not in new_node.input[5]:
+                    min_filter_node = self.graph_info[new_node.input[5]].node
             last_node = self.graph_info[new_node.input[0]].node
             is_min_first = bool(quantized_node.attr['input_quant_mode'].s == b'MIN_FIRST')
             weight_node = self.graph_info[new_node.input[1]].node
diff --git a/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py b/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py
index 863876590a5..62269322bc5 100644
--- a/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py
+++ b/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py
@@ -81,6 +81,8 @@ def do_transformation(self):
         self.g.graph = copy.deepcopy(self.model)
         self.graph_info = self.g.parse_graph()
 
+        self.g.get_frame_info()
+
         # insert QDQ pattern for op's input
         for op_name in quantizable_op_names:
             if self._ignore_insert_qdq_pattern(op_name):
@@ -115,20 +117,16 @@ def do_transformation(self):
                     computational_node = self.graph_info[computational_node_name].node
 
                     weight_name = computational_node.input[1]
-                    weight_node = self.graph_info[weight_name].node
                     if re.search(r"\w+:\d+", weight_name):
                         weight_node = self.graph_info[weight_name.rsplit(':', 1)[0]].node
                     else:
                         weight_node = self.graph_info[weight_name].node
-                    enter_node = None
                     if weight_node.op == 'Enter':
                         if self.itex_mode:
                             parent_node = self.graph_info[Helper.node_name_from_input(weight_node.input[0])].node
                             if not parent_node.op == 'Const':
                                 continue
-                            else:
-                                enter_node = weight_node
-                                weight_node = parent_node
+                            weight_node = parent_node
                         else:
                             continue
 
@@ -139,10 +137,10 @@ def do_transformation(self):
                     else:
                         per_channel = False
                         weight_bit = 7
-
+
                     self._insert_qdq_pattern_for_weight_node(computational_node,
                                                              weight_node,
-                                                             enter_node,
+                                                             weight_name,
                                                              min_max_values,
                                                              per_channel,
                                                              weight_bit,
@@ -414,7 +412,7 @@ def _insert_qdq_pattern_for_each_input(self, op_name, namespace_prefix,
     def _insert_qdq_pattern_for_weight_node(self,
                                             computational_node,
                                             weight_node,
-                                            enter_node,
+                                            weight_name,
                                             min_max_values,
                                             per_channel,
                                             weight_bit=7.0,
@@ -504,41 +502,27 @@ def _insert_qdq_pattern_for_weight_node(self,
         max_node = Helper.create_constant_node(max_name, max_value, dtypes.float32, device="cpu")
 
         if "BatchMatMul" in host_op_type and "BatchMatMul" not in weight_node.op:
-            min_node.input.append("^" + weight_node.name)
-            max_node.input.append("^" + weight_node.name)
+            min_node.input.append("^" + weight_name)
+            max_node.input.append("^" + weight_name)
 
-        quant_const_enter_node = None
         min_enter_node = None
         max_enter_node = None
-        if enter_node:
-            quant_const_enter_node = Helper.create_node('Enter', \
-                qint8_const_name + '_enter', [weight_node.name])
-            Helper.set_attr_string(quant_const_enter_node,
-                                   'frame_name', enter_node.attr['frame_name'].s)
-            Helper.set_attr_dtype(quant_const_enter_node, 'T', dtypes.float32)
-            Helper.set_attr_bool(quant_const_enter_node, 'is_constant', True)
-            Helper.set_attr_int(quant_const_enter_node, \
-                'parallel_iterations', enter_node.attr['parallel_iterations'].i)
+        if insert_reshape:
+            reshape_dims_4to3_name = qint8_const_name + "_reshape_dims_4to3_"
+            reshape_dims_4to3_node = Helper.create_constant_node(
+                reshape_dims_4to3_name, shape_convert, dtypes.int32)
+            reshape_4to3_name = qint8_const_name + "_reshape_4to3_"
+            reshape_4to3_node = Helper.create_node("Reshape", reshape_4to3_name,
+                                                   [weight_node.name, reshape_dims_4to3_name])
+            reshape_4to3_node.attr["T"].CopyFrom(
+                attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
             quant_node = Helper.create_node(
                 "QuantizeV2", qint8_const_name + '_quant',
-                [quant_const_enter_node.name, min_name, max_name])
+                [reshape_4to3_name, min_name, max_name])
         else:
-            if insert_reshape:
-                reshape_dims_4to3_name = qint8_const_name + "_reshape_dims_4to3_"
-                reshape_dims_4to3_node = Helper.create_constant_node(
-                    reshape_dims_4to3_name, shape_convert, dtypes.int32)
-                reshape_4to3_name = qint8_const_name + "_reshape_4to3_"
-                reshape_4to3_node = Helper.create_node("Reshape", reshape_4to3_name,
-                                                       [weight_node.name, reshape_dims_4to3_name])
-                reshape_4to3_node.attr["T"].CopyFrom(
-                    attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
-                quant_node = Helper.create_node(
-                    "QuantizeV2", qint8_const_name + '_quant',
-                    [reshape_4to3_name, min_name, max_name])
-            else:
-                quant_node = Helper.create_node(
-                    "QuantizeV2", qint8_const_name + '_quant',
-                    [weight_node.name, min_name, max_name])
+            quant_node = Helper.create_node(
+                "QuantizeV2", qint8_const_name + '_quant',
+                [weight_node.name, min_name, max_name])
 
         dequant_node = Helper.create_node(
             "Dequantize", base_name + '_dequant',
@@ -549,10 +533,10 @@ def _insert_qdq_pattern_for_weight_node(self,
         Helper.set_attr_dtype(dequant_node, "T", dtypes.qint8)
         Helper.set_attr_string(dequant_node, "mode", b"SCALED")
         if per_channel:
-            if host_op_type == 'Conv2D' or host_op_type == 'Conv2DBackpropInput':
+            if host_op_type in ('Conv2D', 'Conv2DBackpropInput'):
                 Helper.set_attr_int(quant_node, 'axis', 3)
                 Helper.set_attr_int(dequant_node, 'axis', 3)
-            elif host_op_type == 'Conv3D' or host_op_type == 'Conv3DBackpropInputV2':
+            elif host_op_type in ('Conv3D', 'Conv3DBackpropInputV2'):
                 Helper.set_attr_int(quant_node, 'axis', 4)
                 Helper.set_attr_int(dequant_node, 'axis', 4)
             elif host_op_type == 'MatMul':
@@ -584,25 +568,24 @@ def _insert_qdq_pattern_for_weight_node(self,
            self.g_weight.add_node(reshape_3to4_node, dequant_node.name, [computational_node.name])
            computational_node.input[1] = reshape_3to4_node.name
        else:
-           if enter_node:
+           if weight_node.name in self.g.parent_frame_details and self.g.parent_frame_details[weight_node.name]:
                min_enter_node = Helper.create_node('Enter', min_name + '_enter', [min_name])
-               Helper.set_attr_string(min_enter_node,
-                                      'frame_name', enter_node.attr['frame_name'].s)
+               Helper.set_attr_string(min_enter_node, 'frame_name',
+                   self.g.parent_frame_details[weight_node.name].attr['frame_name'].s)
                Helper.set_attr_dtype(min_enter_node, 'T', dtypes.float32)
                Helper.set_attr_bool(min_enter_node, 'is_constant', True)
                Helper.set_attr_int(min_enter_node, 'parallel_iterations', \
-                                   enter_node.attr['parallel_iterations'].i)
+                   self.g.parent_frame_details[weight_node.name].attr['parallel_iterations'].i)
 
                max_enter_node = Helper.create_node('Enter', max_name + '_enter', [max_name])
-               Helper.set_attr_string(max_enter_node,
-                                      'frame_name', enter_node.attr['frame_name'].s)
+               Helper.set_attr_string(max_enter_node, 'frame_name',
+                   self.g.parent_frame_details[weight_node.name].attr['frame_name'].s)
                Helper.set_attr_dtype(max_enter_node, 'T', dtypes.float32)
                Helper.set_attr_bool(max_enter_node, 'is_constant', True)
                Helper.set_attr_int(max_enter_node, 'parallel_iterations',\
-                                   enter_node.attr['parallel_iterations'].i)
+                   self.g.parent_frame_details[weight_node.name].attr['parallel_iterations'].i)
 
-               self.g_weight.add_node(quant_const_enter_node, weight_node.name, [quant_node.name])
-               self.g_weight.add_node(quant_node, quant_const_enter_node.name, [])
+               self.g_weight.add_node(quant_node, weight_name, [])
                self.g_weight.add_node(min_node, None, [min_enter_node.name])
                self.g_weight.add_node(max_node, None, [max_enter_node.name])
                self.g_weight.add_node(min_enter_node, min_node.name, [quant_node.name])
@@ -610,7 +593,7 @@ def _insert_qdq_pattern_for_weight_node(self,
                self.g_weight.add_node(dequant_node, quant_node.name, [computational_node.name])
                computational_node.input[1] = dequant_node.name
            else:
-               self.g_weight.add_node(quant_node, weight_node.name, [])
+               self.g_weight.add_node(quant_node, weight_name, [])
                self.g_weight.add_node(min_node, None, [quant_node.name])
                self.g_weight.add_node(max_node, None, [quant_node.name])
                self.g_weight.add_node(dequant_node, quant_node.name, [computational_node.name])
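
For reviewers, a minimal standalone sketch of the index selection in the fuse_matmul_requantize.py hunk above. The input layout is inferred from the diff itself: an unfused `_QuantizedMatMul` takes `[x, w, x_min, x_max, w_min, w_max]`, so the weight min/max sit at inputs 4 and 5, while a fused variant (e.g. with BiasAdd) gains an extra tensor input before them and shifts the pair to 5 and 6. The helper below is hypothetical, not a neural_compressor API:

```python
# Hypothetical illustration only; weight_min_max_indices is not a real
# neural_compressor helper, and the input layout is inferred from the patch.
def weight_min_max_indices(fused_ops):
    """Return (min_idx, max_idx) of the weight min/max inputs of a _QuantizedMatMul."""
    if len(fused_ops) == 0:
        # Single MatMul: inputs are [x, w, x_min, x_max, w_min, w_max].
        return 4, 5
    # Fused (e.g. BiasAdd): the extra tensor occupies input[2], shifting both by one.
    return 5, 6

assert weight_min_max_indices([]) == (4, 5)
assert weight_min_max_indices([b"BiasAdd"]) == (5, 6)
```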