diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index a7b3dd5792a02..a4c7a2a2bf8df 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -387,23 +387,27 @@ def quantize(self):
                     break
         _logger.info("Finish sampling stage, all batch: " + str(batch_id))
 
-        if self._round_type == 'adaround':
-            self._adaround_apply()
-
-        self._reset_activation_persistable()
         if self._algo == 'avg':
             for var_name in self._quantized_act_var_name:
                 self._quantized_threshold[var_name] = \
                     np.array(self._quantized_var_avg[var_name]).mean()
         if self._algo in ["KL", "hist"]:
             self._calculate_kl_hist_threshold()
-        if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]:
-            self._update_program()
-        else:
+
+        if self._round_type == 'adaround':
+            self._adaround_apply()
+
+        self._reset_activation_persistable()
+
+        if self._algo == 'min_max':
             self._save_input_threhold()
+        else:
+            self._update_program()
+
         # save out_threshold for quantized ops.
         if not self._onnx_format:
             self._save_output_threshold()
+
         if any(op_type in self._quantizable_op_type
                for op_type in self._dynamic_quantize_op_type):
             self._collect_dynamic_quantize_op_threshold(
@@ -428,6 +432,7 @@ def quantize(self):
         return self._program
 
     def _adaround_apply(self):
+        assert self._algo != "min_max", "The algo should not be min_max."
         if self._algo in ["KL", "hist"]:
             scale_dict = self._quantized_var_threshold
         else:
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
index 58a430eb96406..85cabb6b5e9b7 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
@@ -173,7 +173,8 @@ def generate_quantized_model(self,
                                  is_use_cache_file=False,
                                  is_optimize_model=False,
                                  batch_size=10,
-                                 batch_nums=10):
+                                 batch_nums=10,
+                                 onnx_format=False):
 
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
@@ -190,14 +191,28 @@ def generate_quantized_model(self,
             round_type=round_type,
             is_full_quantize=is_full_quantize,
             optimize_model=is_optimize_model,
+            onnx_format=onnx_format,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
         ptq.save_quantized_model(self.int8_model_path)
 
-    def run_test(self, model_name, model_url, model_md5, data_name, data_url,
-                 data_md5, algo, round_type, quantizable_op_type,
-                 is_full_quantize, is_use_cache_file, is_optimize_model,
-                 diff_threshold, infer_iterations, quant_iterations):
+    def run_test(self,
+                 model_name,
+                 model_url,
+                 model_md5,
+                 data_name,
+                 data_url,
+                 data_md5,
+                 algo,
+                 round_type,
+                 quantizable_op_type,
+                 is_full_quantize,
+                 is_use_cache_file,
+                 is_optimize_model,
+                 diff_threshold,
+                 infer_iterations,
+                 quant_iterations,
+                 onnx_format=False):
         fp32_model_path = self.download_model(model_url, model_md5, model_name)
         fp32_model_path = os.path.join(fp32_model_path, model_name)
 
@@ -211,10 +226,10 @@ def run_test(self, model_name, model_url, model_md5, data_name, data_url,
 
         print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations)) - self.generate_quantized_model(fp32_model_path, data_path, algo, - round_type, quantizable_op_type, - is_full_quantize, is_use_cache_file, - is_optimize_model, quant_iterations) + self.generate_quantized_model( + fp32_model_path, data_path, algo, round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + quant_iterations, onnx_format) print("Start INT8 inference for {0} on {1} samples ...".format( model_name, infer_iterations)) @@ -278,5 +293,42 @@ def test_post_training_kl(self): diff_threshold, infer_iterations, quant_iterations) +class TestPostTrainingKLForMnistONNXFormat(TestPostTrainingQuantization): + def test_post_training_kl_onnx_format(self): + model_name = "nlp_lstm_fp32_model" + model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" + model_md5 = "519b8eeac756e7b4b7bcb2868e880452" + data_name = "quant_lstm_input_data" + data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" + data_md5 = "add84c754e9b792fea1fbd728d134ab7" + algo = "KL" + round_type = "round" + quantizable_op_type = ["mul", "lstm"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + diff_threshold = 0.01 + infer_iterations = 100 + quant_iterations = 10 + onnx_format = True + self.run_test( + model_name, + model_url, + model_md5, + data_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py index 74198da11fb2c..c219d2fbf89a9 100644 --- 
a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py @@ -116,7 +116,8 @@ def generate_quantized_model(self, is_use_cache_file=False, is_optimize_model=False, batch_size=10, - batch_nums=10): + batch_nums=10, + onnx_format=False): place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -134,6 +135,7 @@ def generate_quantized_model(self, round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, + onnx_format=onnx_format, is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model_path) @@ -151,7 +153,8 @@ def run_test(self, diff_threshold, batch_size=10, infer_iterations=10, - quant_iterations=5): + quant_iterations=5, + onnx_format=False): origin_model_path = self.download_model(data_url, data_md5, model_name) origin_model_path = os.path.join(origin_model_path, model_name) @@ -166,7 +169,7 @@ def run_test(self, self.generate_quantized_model(origin_model_path, algo, round_type, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, - batch_size, quant_iterations) + batch_size, quant_iterations, onnx_format) print("Start INT8 inference for {0} on {1} images ...".format( model_name, infer_iterations * batch_size)) @@ -335,5 +338,72 @@ def test_post_training_mse(self): infer_iterations, quant_iterations) +class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): + def test_post_training_mse_onnx_format(self): + model_name = "mnist_model" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" + data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" + algo = "mse" + round_type = "round" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + onnx_format = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + 
self.run_test( + model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) + + +class TestPostTrainingmseForMnistONNXFormatFullQuant( + TestPostTrainingQuantization): + def test_post_training_mse_onnx_format_full_quant(self): + model_name = "mnist_model" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" + data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" + algo = "mse" + round_type = "round" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = True + is_use_cache_file = False + is_optimize_model = False + onnx_format = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test( + model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py new file mode 100644 index 0000000000000..99b00bc0c6897 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+
+
+def quantize_max_abs(x, max_range):
+    scale = np.max(np.abs(x).flatten())
+    y = np.round(x / scale * max_range)
+    return y, scale
+
+
+def dequantize_max_abs(x, scale, max_range):
+    y = (scale / max_range) * x
+    return y
+
+
+def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
+    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
+    scales = []
+    y = x.copy()
+    max_range = math.pow(2, quant_bit - 1) - 1
+    if quant_axis == 0:
+        for i in range(x.shape[0]):
+            scale = np.max(np.abs(x[i])).astype("float32")
+            scales.append(scale)
+            y[i] = np.round(x[i] * max_range / scale)
+    elif quant_axis == 1:
+        for i in range(x.shape[1]):
+            scale = np.max(np.abs(x[:, i])).astype("float32")
+            scales.append(scale)
+            y[:, i] = np.round(x[:, i] * max_range / scale)
+    return y, scales
+
+
+def channel_wise_dequantize_max_abs(x,
+                                    scales,
+                                    quant_bits,
+                                    quant_axis,
+                                    activation_scale=None):
+    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
+
+    if isinstance(quant_bits, list):
+        max_range = math.pow(2, quant_bits[0] - 1) - 1
+    else:
+        max_range = math.pow(2, quant_bits - 1) - 1
+    y = x.copy()
+    if quant_axis == 0:
+        for i in range(x.shape[0]):
+            y[i] = x[i] * scales[i] / max_range
+    elif quant_axis == 1:
+        for i in range(x.shape[1]):
+            y[:, i] = x[:, i] * scales[i] / max_range
+
+    if activation_scale is not None:
+        y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
+    return y
+
+
+class TestChannelWiseDequantizeOp(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 0
+        self.zero_point = 0.
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "dequantize_linear"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
+                                                  self.quant_axis)
+        ydq = channel_wise_dequantize_max_abs(yq, scale, self.bit_length,
+                                              self.quant_axis)
+
+        self.inputs = {
+            'X': yq,
+            'Scale': np.array(scale).astype(self.data_type),
+            'ZeroPoint': self.zero_point
+        }
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': ydq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestChannelWiseDequantizeOp1(TestChannelWiseDequantizeOp):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 1
+        self.zero_point = 0.
+
+
+class TestDequantizeOp(OpTest):
+    def set_args(self):
+        self.num_bits = 8
+        self.quant_axis = -1
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float32"
+        self.zero_point = 0.
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "dequantize_linear"
+        x = np.random.randn(31, 65).astype(self.data_type)
+        yq, scale = quantize_max_abs(x, self.max_range)
+        ydq = dequantize_max_abs(yq, scale, self.max_range)
+
+        self.inputs = {
+            'X': yq,
+            'Scale': np.array(scale).astype(self.data_type),
+            'ZeroPoint': self.zero_point
+        }
+        self.attrs = {
+            'bit_length': self.num_bits,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': ydq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDequantizeOpDouble(TestDequantizeOp):
+    def set_args(self):
+        self.num_bits = 8
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float64"
+        self.zero_point = 0.
+        self.quant_axis = -1
+
+
+class TestFakeDequantizeMaxAbsOp5Bits(TestDequantizeOp):
+    def set_args(self):
+        self.num_bits = 5
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float32"
+        self.zero_point = 0.
+        self.quant_axis = -1
+
+
+class TestChannelWisequantizeOp(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 0
+        self.zero_point = 0.
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "quantize_linear"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
+                                                  self.quant_axis)
+
+        self.inputs = {
+            'X': x,
+            'Scale': np.array(scale).astype(self.data_type),
+            'ZeroPoint': self.zero_point
+        }
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': yq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestChannelWisequantizeOp1(TestChannelWisequantizeOp):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 1
+        self.zero_point = 0.
+
+
+class TestquantizeOp(OpTest):
+    def set_args(self):
+        self.num_bits = 8
+        self.quant_axis = -1
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float32"
+        self.zero_point = 0.
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "quantize_linear"
+        x = np.random.randn(31, 65).astype(self.data_type)
+        yq, scale = quantize_max_abs(x, self.max_range)
+
+        self.inputs = {
+            'X': x,
+            'Scale': np.array(scale).astype(self.data_type),
+            'ZeroPoint': self.zero_point
+        }
+        self.attrs = {
+            'bit_length': self.num_bits,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': yq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()