diff --git a/.azure-pipelines/scripts/ut/env_setup.sh b/.azure-pipelines/scripts/ut/env_setup.sh
index 07fa00a8d35..6a9fd879fad 100644
--- a/.azure-pipelines/scripts/ut/env_setup.sh
+++ b/.azure-pipelines/scripts/ut/env_setup.sh
@@ -20,7 +20,7 @@ echo "mxnet version is $mxnet_version"
 if [[ "${tensorflow_version}" == *"-official" ]]; then
     pip install tensorflow==${tensorflow_version%-official}
 elif [[ "${tensorflow_version}" == "spr-base" ]]; then
-    pip install /tf_dataset/tf_binary/tensorflow*.whl
+    pip install /tf_dataset/tf_binary/221125/tensorflow*.whl
     if [[ $? -ne 0 ]]; then
         exit 1
     fi
diff --git a/neural_compressor/adaptor/tensorflow.yaml b/neural_compressor/adaptor/tensorflow.yaml
index 5502158a443..36e6a674edf 100644
--- a/neural_compressor/adaptor/tensorflow.yaml
+++ b/neural_compressor/adaptor/tensorflow.yaml
@@ -299,6 +299,7 @@
         'Dequantize + DepthwiseConv2dNative + Add + Relu6 + QuantizeV2',
         'Dequantize + DepthwiseConv2dNative + BiasAdd + QuantizeV2',
         'Dequantize + FusedBatchNormV3 + Relu + QuantizeV2',
+        'Dequantize + FusedBatchNormV3 + LeakyRelu + QuantizeV2',
        'Dequantize + _MklFusedInstanceNorm + Relu + QuantizeV2',
         'Dequantize + _MklFusedInstanceNorm + LeakyRelu + QuantizeV2',
         'Dequantize + Conv2DBackpropInput + BiasAdd + QuantizeV2',
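The new sequence teaches the quantizer to collapse Dequantize + FusedBatchNormV3 + LeakyRelu + QuantizeV2 into a single quantized op. As a float reference for what the fused kernel must reproduce, here is a minimal sketch, assuming inference-mode batch norm with a small epsilon; all names here are illustrative and not part of the patch:

```python
import numpy as np

def bn_leakyrelu_reference(x, scale, offset, mean, variance, alpha=0.3, eps=1e-3):
    """Per-channel inference-mode batch norm followed by LeakyRelu."""
    y = scale * (x - mean) / np.sqrt(variance + eps) + offset
    return np.where(y > 0, y, alpha * y)  # LeakyRelu keeps slope alpha for y < 0
```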
diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_bn.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_bn.py
index 9dbe1c82f0a..f36b02a3e94 100644
--- a/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_bn.py
+++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_bn.py
@@ -31,8 +31,9 @@ def __init__(self, **kwargs):
                                       reverse=True)
         if self.new_api:
             self.fusion_mapping = {
+                'FusedBatchNormV3': self.apply_newly_bn_relu_fusion,
                 'FusedBatchNormV3Relu': self.apply_newly_bn_relu_fusion,
-                'FusedBatchNormV3': self.apply_newly_bn_relu_fusion
+                'FusedBatchNormV3LeakyRelu': self.apply_newly_bn_leakyrelu_fusion
             }
         else:
             self.fusion_mapping = {}
@@ -75,8 +76,7 @@ def apply_newly_bn_relu_fusion(self, match_node_name):
                     [output_min_node_name] + [output_max_node_name] + control_inputs
                 output_min_node = helper.create_constant_node(output_min_node_name, -1., dtypes.float32)
                 output_max_node = helper.create_constant_node(output_max_node_name, 1., dtypes.float32)
-                quantized_bn_node = helper.create_node(node_op, quantized_node_name,
-                                                       quantized_node_input_names)
+                quantized_bn_node = helper.create_node(node_op, quantized_node_name, quantized_node_input_names)
                 if relu_node_name is not None:
                     helper.set_attr_string(quantized_bn_node, "activation_mode", b'Relu')
                 if self.node_name_mapping[offset_name].node.op == "Const":
@@ -141,6 +141,108 @@ def apply_newly_bn_relu_fusion(self, match_node_name):
                 new_node.CopyFrom(node)
                 self.add_output_graph_node(new_node)
 
+    def apply_newly_bn_leakyrelu_fusion(self, match_node_name):
+        matched_node = self.node_name_mapping[match_node_name[0]]
+        skip_node_name = match_node_name[1:]
+        control_inputs, normal_inputs = self._get_node_input(
+            matched_node.node.name)
+        scale_name = normal_inputs[1]
+        offset_name = normal_inputs[2]
+        mean_name = normal_inputs[3]
+        variance_name = normal_inputs[4]
+
+        all_input_names = self._add_eightbit_prologue_nodes(matched_node.node.name)
+        all_input_names = [
+            all_input_names[0],
+            scale_name,
+            offset_name,
+            mean_name,
+            variance_name,
+            all_input_names[1],
+            all_input_names[2]
+        ]
+
+        for _, node in enumerate(self.input_graph.node):
+            if node.name in skip_node_name:
+                self.logger.debug("skip node {}".format(node.name))
+            elif node.name == match_node_name[0]:
+                self.logger.debug("Matched node {} with input {}.".format(node.name, node.input))
+                leakyrelu_node_name = match_node_name[1]
+                node_op = '_QuantizedFusedBatchNorm'
+                quantized_node_name = node.name + "_eightbit_quantized_bn"
+                output_min_node_name = quantized_node_name + "_input7_output_min"
+                output_max_node_name = quantized_node_name + "_input8_output_max"
+                quantized_node_input_names = all_input_names + \
+                    [output_min_node_name] + [output_max_node_name] + control_inputs
+                output_min_node = helper.create_constant_node(output_min_node_name, -1., dtypes.float32)
+                output_max_node = helper.create_constant_node(output_max_node_name, 1., dtypes.float32)
+                quantized_bn_node = helper.create_node(node_op, quantized_node_name, quantized_node_input_names)
+
+                helper.set_attr_string(quantized_bn_node, "activation_mode", b'LeakyRelu')
+                helper.copy_attr(quantized_bn_node, "alpha",
+                                 self.node_name_mapping[leakyrelu_node_name].node.attr["alpha"])
+                if self.node_name_mapping[offset_name].node.op == "Const":
+                    helper.set_attr_bool(quantized_bn_node, "is_offset_const", True)
+                else:
+                    helper.set_attr_bool(quantized_bn_node, "is_offset_const", False)
+                if self.node_name_mapping[mean_name].node.op == "Const":
+                    helper.set_attr_bool(quantized_bn_node, "is_mean_const", True)
+                else:
+                    helper.set_attr_bool(quantized_bn_node, "is_mean_const", False)
+                helper.set_attr_dtype(quantized_bn_node, "T", dtypes.qint8)
+                helper.set_attr_dtype(quantized_bn_node, "U", dtypes.float32)
+                helper.set_attr_dtype(quantized_bn_node, "Tout", dtypes.qint8)
+
+                # Input layout of _QuantizedFusedBatchNorm:
+                # 0. x
+                # 1. scale
+                # 2. offset
+                # 3. mean
+                # 4. variance
+                # 5. x_min
+                # 6. x_max
+                # 7. {output_min}
+                # 8. {output_max}
+                helper.set_attr_type_list(quantized_bn_node, 'input_types', [
+                    dtypes.qint8.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                ])
+
+                # Output layout:
+                # 0. output
+                # 1. output_min
+                # 2. output_max
+                helper.set_attr_type_list(quantized_bn_node, 'out_types', [
+                    dtypes.qint8.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                ])
+                self.add_output_graph_node(output_min_node)
+                self.add_output_graph_node(output_max_node)
+                self.add_output_graph_node(quantized_bn_node)
+                self._intel_cpu_add_dequantize_result_node(
+                    quantized_output_name=quantized_node_name,
+                    original_node_name=match_node_name[-1],
+                    dtype=dtypes.qint8,
+                    min_tensor_index=1,
+                    performance_only=self.performance_only
+                )
+
+            else:
+                new_node = node_def_pb2.NodeDef()
+                new_node.CopyFrom(node)
+                self.add_output_graph_node(new_node)
+
     def get_longest_fuse(self):
         self._get_op_list()
         real_patterns = [pattern[1 :-1] for pattern in self.sorted_patterns]
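The handler pins the fused op's qint8 output range with two Const inputs (-1.0 and 1.0); a later range-freezing pass can rewrite the max, which the new test below checks. As a minimal sketch of how a qint8 tensor maps back to float, assuming TF's symmetric SCALED quantization with 127 positive levels (an assumption, not code from this patch):

```python
import numpy as np

def dequantize_qint8(q, out_min, out_max):
    # Symmetric (SCALED-mode) dequantization: the range is centered on zero,
    # so the larger magnitude of min/max determines the step size.
    scale = max(abs(out_min), abs(out_max)) / 127.0
    return q.astype(np.float32) * scale
```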
diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/quantize_graph_bn.py b/neural_compressor/adaptor/tf_utils/quantize_graph/quantize_graph_bn.py
index 5bf86c74e72..9a425505dc1 100644
--- a/neural_compressor/adaptor/tf_utils/quantize_graph/quantize_graph_bn.py
+++ b/neural_compressor/adaptor/tf_utils/quantize_graph/quantize_graph_bn.py
@@ -31,8 +31,9 @@ def __init__(self, **kwargs):
                                       reverse=True)
         if self.new_api:
             self.fusion_mapping = {
+                'FusedBatchNormV3': self.apply_newly_bn_relu_fusion,
                 'FusedBatchNormV3Relu': self.apply_newly_bn_relu_fusion,
-                'FusedBatchNormV3': self.apply_newly_bn_relu_fusion
+                'FusedBatchNormV3LeakyRelu': self.apply_newly_bn_leakyrelu_fusion
             }
         else:
             self.fusion_mapping = {}
@@ -75,8 +76,7 @@ def apply_newly_bn_relu_fusion(self, match_node_name):
                     [output_min_node_name] + [output_max_node_name] + control_inputs
                 output_min_node = helper.create_constant_node(output_min_node_name, -1., dtypes.float32)
                 output_max_node = helper.create_constant_node(output_max_node_name, 1., dtypes.float32)
-                quantized_bn_node = helper.create_node(node_op, quantized_node_name,
-                                                       quantized_node_input_names)
+                quantized_bn_node = helper.create_node(node_op, quantized_node_name, quantized_node_input_names)
                 if relu_node_name is not None:
                     helper.set_attr_string(quantized_bn_node, "activation_mode", b'Relu')
                 if self.node_name_mapping[offset_name].node.op == "Const":
@@ -140,6 +140,108 @@ def apply_newly_bn_relu_fusion(self, match_node_name):
                 new_node.CopyFrom(node)
                 self.add_output_graph_node(new_node)
 
+    def apply_newly_bn_leakyrelu_fusion(self, match_node_name):
+        matched_node = self.node_name_mapping[match_node_name[0]]
+        skip_node_name = match_node_name[1:]
+        control_inputs, normal_inputs = self._get_node_input(
+            matched_node.node.name)
+        scale_name = normal_inputs[1]
+        offset_name = normal_inputs[2]
+        mean_name = normal_inputs[3]
+        variance_name = normal_inputs[4]
+
+        all_input_names = self._add_eightbit_prologue_nodes(matched_node.node.name)
+        all_input_names = [
+            all_input_names[0],
+            scale_name,
+            offset_name,
+            mean_name,
+            variance_name,
+            all_input_names[1],
+            all_input_names[2]
+        ]
+
+        for _, node in enumerate(self.input_graph.node):
+            if node.name in skip_node_name:
+                self.logger.debug("skip node {}".format(node.name))
+            elif node.name == match_node_name[0]:
+                self.logger.debug("Matched node {} with input {}.".format(node.name, node.input))
+                leakyrelu_node_name = match_node_name[1]
+                node_op = '_QuantizedFusedBatchNorm'
+                quantized_node_name = node.name + "_eightbit_quantized_bn"
+                output_min_node_name = quantized_node_name + "_input7_output_min"
+                output_max_node_name = quantized_node_name + "_input8_output_max"
+                quantized_node_input_names = all_input_names + \
+                    [output_min_node_name] + [output_max_node_name] + control_inputs
+                output_min_node = helper.create_constant_node(output_min_node_name, -1., dtypes.float32)
+                output_max_node = helper.create_constant_node(output_max_node_name, 1., dtypes.float32)
+                quantized_bn_node = helper.create_node(node_op, quantized_node_name, quantized_node_input_names)
+
+                helper.set_attr_string(quantized_bn_node, "activation_mode", b'LeakyRelu')
+                helper.copy_attr(quantized_bn_node, "alpha",
+                                 self.node_name_mapping[leakyrelu_node_name].node.attr["alpha"])
+                if self.node_name_mapping[offset_name].node.op == "Const":
+                    helper.set_attr_bool(quantized_bn_node, "is_offset_const", True)
+                else:
+                    helper.set_attr_bool(quantized_bn_node, "is_offset_const", False)
+                if self.node_name_mapping[mean_name].node.op == "Const":
+                    helper.set_attr_bool(quantized_bn_node, "is_mean_const", True)
+                else:
+                    helper.set_attr_bool(quantized_bn_node, "is_mean_const", False)
+                helper.set_attr_dtype(quantized_bn_node, "T", dtypes.qint8)
+                helper.set_attr_dtype(quantized_bn_node, "U", dtypes.float32)
+                helper.set_attr_dtype(quantized_bn_node, "Tout", dtypes.qint8)
+
+                # Input layout of _QuantizedFusedBatchNorm:
+                # 0. x
+                # 1. scale
+                # 2. offset
+                # 3. mean
+                # 4. variance
+                # 5. x_min
+                # 6. x_max
+                # 7. {output_min}
+                # 8. {output_max}
+                helper.set_attr_type_list(quantized_bn_node, 'input_types', [
+                    dtypes.qint8.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                ])
+
+                # Output layout:
+                # 0. output
+                # 1. output_min
+                # 2. output_max
+                helper.set_attr_type_list(quantized_bn_node, 'out_types', [
+                    dtypes.qint8.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                    dtypes.float32.as_datatype_enum,
+                ])
+                self.add_output_graph_node(output_min_node)
+                self.add_output_graph_node(output_max_node)
+                self.add_output_graph_node(quantized_bn_node)
+                self._intel_cpu_add_dequantize_result_node(
+                    quantized_output_name=quantized_node_name,
+                    original_node_name=match_node_name[-1],
+                    dtype=dtypes.qint8,
+                    min_tensor_index=1,
+                    performance_only=self.performance_only
+                )
+
+            else:
+                new_node = node_def_pb2.NodeDef()
+                new_node.CopyFrom(node)
+                self.add_output_graph_node(new_node)
+
     def get_longest_fuse(self):
         self._get_op_list()
         matched_rule, matched_node_name = self._is_match(self.sorted_patterns)
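Both passes depend on self.sorted_patterns being ordered longest-first (the sorted(..., reverse=True) call visible in the __init__ context above), so the four-op BN + LeakyRelu sequence is tried before the bare BN pattern. A toy illustration of why that ordering matters:

```python
# Longest-first ordering makes the matcher prefer the fuller fusion.
patterns = [
    ['Dequantize', 'FusedBatchNormV3', 'QuantizeV2'],
    ['Dequantize', 'FusedBatchNormV3', 'LeakyRelu', 'QuantizeV2'],
]
sorted_patterns = sorted(patterns, key=lambda i: len(i), reverse=True)
assert sorted_patterns[0] == ['Dequantize', 'FusedBatchNormV3', 'LeakyRelu', 'QuantizeV2']
```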
diff --git a/test/tfnewapi/test_tensorflow_graph_qdq_bn_fusion.py b/test/tfnewapi/test_tensorflow_graph_qdq_bn_fusion.py
index acfbd049072..d99a48c1803 100644
--- a/test/tfnewapi/test_tensorflow_graph_qdq_bn_fusion.py
+++ b/test/tfnewapi/test_tensorflow_graph_qdq_bn_fusion.py
@@ -12,6 +12,8 @@
 from tensorflow.python.framework import dtypes
 from neural_compressor.adaptor.tf_utils.util import disable_random
 from neural_compressor.utils.utility import CpuInfo
+from neural_compressor.experimental import Quantization, common
+from neural_compressor.utils import logger
 
 def build_fake_yaml_1():
     fake_yaml_1 = '''
@@ -91,7 +93,7 @@ def tearDownClass(self):
 
     @disable_random()
     def test_bn_relu_depthwiseconv_biasadd_relu6_fusion(self):
-        logging.getLogger().info("test_depthwiseconv_biasadd_relu_fusion")
+        logger.info("test_bn_relu_depthwiseconv_biasadd_relu6_fusion")
         x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
         conv_weights = tf.compat.v1.get_variable("weight", [3, 3, 16, 16],
                                                  initializer=tf.compat.v1.random_normal_initializer())
@@ -107,7 +109,7 @@ def test_bn_relu_depthwiseconv_biasadd_relu6_fusion(self):
                 sess=sess,
                 input_graph_def=sess.graph_def,
                 output_node_names=[out_name])
-            from neural_compressor.experimental import Quantization, common
+
             quantizer = Quantization('fake_yaml_1.yaml')
             dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
             quantizer.eval_dataloader = common.DataLoader(dataset)
@@ -137,7 +139,7 @@ def test_bn_relu_depthwiseconv_biasadd_relu6_fusion(self):
 
     @disable_random()
     def test_training_bn_relu_depthwiseconv_biasadd_relu6_fusion(self):
-        logging.getLogger().info("test_depthwiseconv_biasadd_relu_fusion")
+        logger.info("test_training_bn_relu_depthwiseconv_biasadd_relu6_fusion")
         x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
         conv_weights = tf.compat.v1.get_variable("weight", [3, 3, 16, 16],
                                                  initializer=tf.compat.v1.random_normal_initializer())
@@ -153,7 +155,7 @@ def test_training_bn_relu_depthwiseconv_biasadd_relu6_fusion(self):
                 sess=sess,
                 input_graph_def=sess.graph_def,
                 output_node_names=[out_name])
-            from neural_compressor.experimental import Quantization, common
+
             quantizer = Quantization('fake_yaml_1.yaml')
             dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
             quantizer.eval_dataloader = common.DataLoader(dataset)
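All the tests in this file drive the same experimental API that the moved imports now expose at module scope. In outline (condensed from the tests themselves; output_graph_def is the frozen GraphDef each test builds first):

```python
from neural_compressor.experimental import Quantization, common

quantizer = Quantization('fake_yaml_1.yaml')
dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
quantizer.eval_dataloader = common.DataLoader(dataset)
quantizer.calib_dataloader = common.DataLoader(dataset)
quantizer.model = output_graph_def   # frozen GraphDef under test
output_graph = quantizer.fit()       # returns the quantized model
```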
@@ -177,9 +179,68 @@ def test_training_bn_relu_depthwiseconv_biasadd_relu6_fusion(self):
         if bf16_enabled:
             self.assertEqual(bf16_bn_num, 1)
 
+    @disable_random()
+    def test_bn_leakyrelu_conv_biasadd_relu(self):
+        logger.info("test_bn_leakyrelu_conv_biasadd_relu")
+        x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
+        conv_weights = tf.compat.v1.get_variable("weight", [3, 3, 16, 16],
+                                                 initializer=tf.compat.v1.random_normal_initializer())
+        normed_0 = tf.compat.v1.layers.batch_normalization(x)
+        leaky_relu = tf.nn.leaky_relu(normed_0, alpha=0.3, name='op_to_store_0')
+        conv = tf.nn.conv2d(leaky_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID")
+        normed_1 = tf.compat.v1.layers.batch_normalization(conv)
+        relu = tf.nn.relu(normed_1, name='op_to_store_1')
+        out_name = relu.name.split(':')[0]
+        with tf.compat.v1.Session() as sess:
+            sess.run(tf.compat.v1.global_variables_initializer())
+            output_graph_def = graph_util.convert_variables_to_constants(
+                sess=sess,
+                input_graph_def=sess.graph_def,
+                output_node_names=[out_name])
+
+            quantizer = Quantization('fake_yaml_1.yaml')
+            dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
+            quantizer.eval_dataloader = common.DataLoader(dataset)
+            quantizer.calib_dataloader = common.DataLoader(dataset)
+
+            quantizer.model = output_graph_def
+            output_graph = quantizer.fit()
+            conv_input_type = True
+            found_fusion = True
+            qbn_num = 0
+            dq_num = 0
+            qbn_output_max_name = 'batch_normalization/FusedBatchNormV3_eightbit_quantized_bn/frozen_bn_output_max'
+            for i in output_graph.graph_def.node:
+                if i.op == '_FusedQuantizedConv2D' \
+                        and i.attr['Thost_inputs'].list.type != [11, 11, 1, 1, 1, 1, 1, 1, 1]:
+                    conv_input_type = False
+                    break
+                if i.op in ['Relu', 'LeakyRelu', 'FusedBatchNormV3']:
+                    found_fusion = False
+                    break
+                if i.op == '_QuantizedFusedBatchNorm':
+                    is_offset_const = i.attr["is_offset_const"].b
+                    is_mean_const = i.attr["is_mean_const"].b
+                    qbn_alpha = i.attr["alpha"].f
+                    frozen_qbn_output_max = i.input[8]
+                    qbn_num += 1
+                if i.name == qbn_output_max_name:
+                    frozen_qbn_output_max_value = i.attr["value"].tensor.float_val[0]
+                if i.op == 'Dequantize':
+                    dq_num += 1
+            self.assertEqual(conv_input_type, True)
+            self.assertEqual(found_fusion, True)
+            self.assertEqual(qbn_num, 1)
+            self.assertEqual(dq_num, 1)
+            self.assertEqual(is_offset_const, True)
+            self.assertEqual(is_mean_const, True)
+            self.assertEqual(round(qbn_alpha, 7), 0.3)
+            self.assertEqual(frozen_qbn_output_max, qbn_output_max_name)
+            self.assertGreater(frozen_qbn_output_max_value, 126)
+
     @disable_random()
     def test_bn_relu_conv_biasadd_relu(self):
-        logging.getLogger().info("test_conv_biasadd_relu_fusion")
+        logger.info("test_bn_relu_conv_biasadd_relu")
         x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
         conv_weights = tf.compat.v1.get_variable("weight", [3, 3, 16, 16],
                                                  initializer=tf.compat.v1.random_normal_initializer())
@@ -195,7 +256,7 @@ def test_bn_relu_conv_biasadd_relu(self):
                 sess=sess,
                 input_graph_def=sess.graph_def,
                 output_node_names=[out_name])
-            from neural_compressor.experimental import Quantization, common
+
             quantizer = Quantization('fake_yaml_1.yaml')
             dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
             quantizer.eval_dataloader = common.DataLoader(dataset)
@@ -236,7 +297,7 @@ def test_bn_relu_conv_biasadd_relu(self):
 
     @disable_random()
     def test_bn_performance_only_false(self):
-        logging.getLogger().info("test_conv_biasadd_relu_fusion")
+        logger.info("test_bn_performance_only_false")
         x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
         conv_weights = tf.compat.v1.get_variable("weight", [3, 3, 16, 16],
                                                  initializer=tf.compat.v1.random_normal_initializer())
@@ -252,7 +313,7 @@ def test_bn_performance_only_false(self):
                 sess=sess,
                 input_graph_def=sess.graph_def,
                 output_node_names=[out_name])
-            from neural_compressor.experimental import Quantization, common
+
             quantizer = Quantization('fake_yaml_2.yaml')
             dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
             quantizer.eval_dataloader = common.DataLoader(dataset)
@@ -281,7 +342,7 @@ def test_bn_performance_only_false(self):
 
     @disable_random()
     def test_bnex_performance_only_false(self):
-        logging.getLogger().info("test_conv_biasadd_relu_fusion")
+        logger.info("test_bnex_performance_only_false")
         x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
         conv_weights_0 = tf.compat.v1.get_variable("weight_0", [3, 3, 16, 16],
                                                    initializer=tf.compat.v1.random_normal_initializer())
@@ -312,7 +373,7 @@ def test_bnex_performance_only_false(self):
             if node.name == "batch_normalization_1/FusedBatchNormV3":
                 node.op = "_FusedBatchNormEx"
                 node.attr["activation_mode"].CopyFrom(attr_value_pb2.AttrValue(s=b"Relu"))
-            from neural_compressor.experimental import Quantization, common
+
             quantizer = Quantization('fake_yaml_2.yaml')
             dataset = quantizer.dataset('dummy', shape=(100, 56, 56, 16), label=True)
             quantizer.eval_dataloader = common.DataLoader(dataset)