From 8905820f45f88b5779e51a9d96cf13c3b8b1d8ce Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sun, 26 Dec 2021 09:58:44 +0000 Subject: [PATCH 1/4] add adaround post-quant method --- .../contrib/slim/quantization/adaround.py | 310 ++++++++++++++++++ .../post_training_quantization.py | 107 +++--- .../slim/quantization/quantization_pass.py | 73 ++--- .../fluid/contrib/slim/quantization/utils.py | 95 ++++++ ..._post_training_quantization_mobilenetv1.py | 61 +++- 5 files changed, 545 insertions(+), 101 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/quantization/adaround.py create mode 100644 python/paddle/fluid/contrib/slim/quantization/utils.py diff --git a/python/paddle/fluid/contrib/slim/quantization/adaround.py b/python/paddle/fluid/contrib/slim/quantization/adaround.py new file mode 100644 index 0000000000000..fa1f45ebeb6c0 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/adaround.py @@ -0,0 +1,310 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import time +import sys +import logging + +import paddle.fluid as fluid + +from ....log_helper import get_logger +from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + +GAMMA = -0.1 +ZETA = 1.1 + + +def compute_soft_rounding(alpha_v): + return fluid.layers.clip( + fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, min=0, max=1) + + +def compute_soft_rounding_np(alpha_v): + return np.clip( + stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, a_min=0, a_max=1) + + +class AdaRoundLoss(object): + def __init__(self, reg_param=0.01, default_beta_range=(20, 2)): + self.default_reg_param = reg_param + self.default_beta_range = default_beta_range + + def compute_recon_loss(self, ada_quantized_output, orig_output): + square_cost = fluid.layers.square_error_cost(ada_quantized_output, + orig_output) + recon_loss = fluid.layers.reduce_mean( + fluid.layers.reduce_sum( + square_cost, dim=-1)) + return recon_loss + + def compute_round_loss(self, alpha_v, warm_start, beta): + def round_loss_fn(): + # compute rectified sigmoid of parameter 'alpha' which maps it between zero and one + h_v = compute_soft_rounding(alpha_v) + + # calculate regularization term - which ensures parameter to converge to exactly zeros and ones + # at the end of optimization + reg_term = fluid.layers.reduce_sum(-fluid.layers.pow( + fluid.layers.abs(2 * h_v - 1), factor=beta) + 1) + + # calculate the rounding loss + round_loss = self.default_reg_param * reg_term + + return round_loss + + round_loss = fluid.layers.cond(warm_start, lambda: fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.0), round_loss_fn) + + return round_loss + + def compute_beta(self, max_iter, cur_iter, warm_start): + + # Start and stop beta for annealing of rounding 
loss (start_beta, end_beta) + start_beta, end_beta = self.default_beta_range + + # iteration at end of warm start period, which is 20% of max iterations + warm_start_end_iter = warm_start * max_iter + + # compute relative iteration of current iteration + rel_iter = (cur_iter - warm_start_end_iter) / ( + max_iter - warm_start_end_iter) + beta = end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * + np.pi)) + + return beta + + +class AdaRound(object): + def __init__(self, + scale, + weight_tensor, + scope=None, + weight_var_name=None, + weight_op_type=None, + is_train=True, + num_iterations=1000): + self.is_train = is_train + self.num_iterations = num_iterations + self.warm_start = 0.1 + self.weight_bits = 8 + self.offset = 0. # zero-point offset + self.adaround_loss = AdaRoundLoss() + self.ori_weight_tensor = weight_tensor + self.scale = scale + self.scope = scope + self.quant_axis = 0 + if weight_op_type in _channelwise_quant_axis1_ops: + self.quant_axis = 1 + self.weight_var_name = weight_var_name + self.alpha_name = weight_var_name + ".alpha" + self.initialize_alpha(weight_tensor.copy(), scale, weight_var_name) + + def initialize_alpha(self, tensor, scale, var_name): + """ + Initializes alpha parameter, same shape as the weight tensor + """ + tensor_scale = quant_tensor(tensor, scale, quant_axis=self.quant_axis) + tensor_floor = np.floor(tensor_scale) + tensor = tensor_scale - tensor_floor + alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1) + self.alpha_v = fluid.layers.create_parameter( + shape=alpha.shape, + dtype="float32", + name=var_name + ".alpha", + default_initializer=fluid.initializer.NumpyArrayInitializer(alpha)) + + def _calculate_output_with_adarounded_weights(self, program, place, exe, + data, fp32_fetch_list, + weight_tensor_dequant): + set_variable_data(self.scope, place, self.weight_var_name, + weight_tensor_dequant) + + adaround_out_tensor = exe.run(program=program, + feed=data, + fetch_list=[fp32_fetch_list], + return_numpy=True, + scope=self.scope) + return adaround_out_tensor + + def _calculate_quant_weight(self): + np_alpha = load_variable_data(self.scope, self.alpha_name) + h_alpha = compute_soft_rounding_np(np_alpha) + + # Scale the tensor + tensor_scale = quant_tensor( + self.ori_weight_tensor.copy(), + self.scale, + quant_axis=self.quant_axis) + + weight_tensor = np.floor(tensor_scale) + + # Adaround the tensor + weight_tensor_quant = np.add(weight_tensor, h_alpha) + return weight_tensor_quant + + def _calculate_adarounded_weights(self): + weight_tensor_quant = self._calculate_quant_weight() + + # Dequantize the tensor + weight_tensor_dequant = dequant_tensor( + weight_tensor_quant + self.offset, + self.scale, + quant_axis=self.quant_axis) + return weight_tensor_dequant + + def update_final_weights(self): + weight_tensor_quant = self._calculate_quant_weight() + return weight_tensor_quant + + def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor): + round_loss = self.adaround_loss.compute_round_loss(self.alpha_v, + warm_start, beta) + recon_loss = self.adaround_loss.compute_recon_loss(adaround_out_tensor, + orig_out_tensor) + loss = round_loss + recon_loss + losses = { + 'loss': loss, + 'round_loss': round_loss, + 'recon_loss': recon_loss + } + return losses + + def update_beta_warm(self, cur_iteration): + warm_start = cur_iteration < self.num_iterations * self.warm_start + beta = self.adaround_loss.compute_beta(self.num_iterations, + cur_iteration, self.warm_start) + return beta, warm_start + + +def run_adaround(data_loader, 
+ fp32_program, + fetch_list, + exe, + scope, + place, + quantized_op_output_name_dict, + weight_op_pairs, + scale_dict, + num_iterations=1000, + lr=0.001, + fast_mode=True): + fetch_op_name = fetch_list[0].name + final_weight_tensor_quant_dict = {} + for weight_var_name, quant_op_out_name in quantized_op_output_name_dict.items( + ): + _logger.info('Start adaround op: {}'.format(weight_var_name)) + weight_op_type = weight_op_pairs[weight_var_name] + # get scale and weight tensor + weight_var_tensor = load_variable_data(scope, weight_var_name) + scale = scale_dict[weight_var_name] + fp32_fetch_list = None + for _op in fp32_program.global_block().ops: + if _op.type == "fetch": + _op._rename_input(fetch_op_name, quant_op_out_name) + fp32_fetch_list = fp32_program.global_block().var( + quant_op_out_name) + fetch_op_name = quant_op_out_name + + # build adaround program + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_iteration_per_drop_scope = 1 + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + # initialize adaround + adaround = AdaRound( + scale, + weight_var_tensor, + scope=scope, + weight_var_name=weight_var_name, + weight_op_type=weight_op_type, + num_iterations=num_iterations) + orig_out_tensor = fluid.data( + name='orig_out_tensor', + shape=fp32_fetch_list.shape, + dtype='float32') + adaround_out_tensor = fluid.data( + name='adaround_out_tensor', + shape=fp32_fetch_list.shape, + dtype='float32') + beta_tensor = fluid.data( + name='beta', shape=[1], dtype='float32') + warm_start_tensor = fluid.data( + name='warm_start', shape=[1], dtype='bool') + + train_fetches_loss = adaround.get_loss( + beta_tensor, warm_start_tensor, adaround_out_tensor, + orig_out_tensor) + optimizer = fluid.optimizer.Adam(learning_rate=lr) + loss = train_fetches_loss['loss'] + optimizer.minimize(loss) + exe.run(startup_program) + + start_time = time.time() + prev_start_time = start_time + for i, data in enumerate(data_loader()): + prev_start_time = start_time + start_time = time.time() + # run fp32 model + np_orig_out_tensor = exe.run(program=fp32_program, + feed=data, + fetch_list=[fp32_fetch_list], + return_numpy=True, + scope=scope) + + adaround_weight_tensor_dequant = adaround._calculate_adarounded_weights( + ) + np_adaround_out_tensor = adaround._calculate_output_with_adarounded_weights( + fp32_program, place, exe, data, fp32_fetch_list, + adaround_weight_tensor_dequant) + + # If the cosine distance of the two tensor is small, skip training + cos_error = calculate_quant_cos_error(np_orig_out_tensor[0], + np_adaround_out_tensor[0]) + if fast_mode and cos_error > 0.99: + _logger.info("The cosine error is small, skip training.") + break + beta, warm_start = adaround.update_beta_warm(i) + feed_dict = { + 'orig_out_tensor': np_orig_out_tensor[0], + 'adaround_out_tensor': np_adaround_out_tensor[0], + 'beta': beta, + 'warm_start': warm_start + } + out = exe.run( + train_program, + feed=feed_dict, + fetch_list=[v.name for v in train_fetches_loss.values()], + return_numpy=True) + _logger.info( + "Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s". 
+ format(i, lr, + np.mean(out[0]), + np.mean(out[1]), + np.mean(out[2]), start_time - prev_start_time)) + sys.stdout.flush() + if i == num_iterations: + break + final_weight_tensor_quant_dict[ + weight_var_name] = adaround.update_final_weights() + del adaround + + # update adarounded calibrated weights + for weight_var_name in quantized_op_output_name_dict.keys(): + set_variable_data(scope, place, weight_var_name, + final_weight_tensor_quant_dict[weight_var_name]) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index e9173a86b89fa..555806f6b2e4f 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -34,6 +34,8 @@ from .quantization_pass import _get_input_name_index from .quantization_pass import _channelwise_quant_axis1_ops from .cal_kl_threshold import cal_kl_threshold +from .adaround import run_adaround +from .utils import load_variable_data, set_variable_data __all__ = ['PostTrainingQuantization', 'WeightQuantization'] @@ -41,28 +43,6 @@ __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') -def _load_variable_data(scope, var_name): - ''' - Load variable value from scope - ''' - var_node = scope.find_var(var_name) - assert var_node is not None, \ - "Cannot find " + var_name + " in scope." - return np.array(var_node.get_tensor()) - - -def _set_variable_data(scope, place, var_name, np_value): - ''' - Set the value of var node by name, if the node exits, - ''' - assert isinstance(np_value, np.ndarray), \ - 'The type of value should be numpy array.' - var_node = scope.find_var(var_name) - if var_node != None: - tensor = var_node.get_tensor() - tensor.set(np_value, place) - - def _all_persistable_var_names(program): persistable_var_names = [] for var in program.list_vars(): @@ -141,6 +121,9 @@ def __init__(self, algo="KL", hist_percent=0.99999, quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], + round_type='round', + train_iterations=1000, + learning_rate=0.001, is_full_quantize=False, bias_correction=False, activation_bits=8, @@ -193,6 +176,12 @@ def __init__(self, quantizable_op_type(list[str], optional): List the type of ops that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. + round_type(str, optional): The method of converting the quantized weights + value from float to int. Currently supports ['round', 'adaround'] methods. + Default is `round`, which is rounding nearest to the nearest whole number. + train_iterations(flota, optional): The number of training iter, used to + calibrate the adaptive rounding method, when round_type='adaround'. + learning_rate(flota, optional): The learning rate of adaround method. is_full_quantized(bool, optional): If set is_full_quantized as True, apply quantization to all supported quantizable op type. 
If set is_full_quantized as False, only apply quantization to the op type @@ -269,6 +258,10 @@ def __init__(self, self._support_algo_type = [ 'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max' ] + assert round_type in ['adaround', 'round'] + self._round_type = round_type + self._train_iterations = train_iterations + self._learning_rate = learning_rate self._dynamic_quantize_op_type = ['lstm'] self._support_quantize_op_type = \ list(set(QuantizationTransformPass._supported_quantizable_op_type + @@ -393,6 +386,10 @@ def quantize(self): if self._batch_nums and batch_id >= self._batch_nums: break _logger.info("Finish sampling stage, all batch: " + str(batch_id)) + + if self._round_type == 'adaround': + self._adaround_apply() + self._reset_activation_persistable() if self._algo == 'avg': for var_name in self._quantized_act_var_name: @@ -429,6 +426,24 @@ def quantize(self): return self._program + def _adaround_apply(self): + if self._algo in ["KL", "hist"]: + scale_dict = self._quantized_var_threshold + else: + scale_dict = self._quantized_threshold + run_adaround( + self._data_loader, + self._program, + self._fetch_list, + self._executor, + self._scope, + self._place, + self._quantized_op_output_name_dict, + self._weight_op_pairs, + scale_dict, + num_iterations=self._train_iterations, + lr=self._learning_rate) + def save_quantized_model(self, save_model_path, model_filename=None, @@ -508,6 +523,7 @@ def _collect_target_varnames(self): ''' # TODO(juncaipeng), consider the name_scope of skip_quant _logger.info("Collect quantized variable names ...") + self._quantized_op_output_name_dict = {} def collect_var_name(var_name_list, persistable_var_names, op_type): for var_name in var_name_list: @@ -533,6 +549,12 @@ def collect_var_name(var_name_list, persistable_var_names, op_type): collect_var_name( _get_op_output_var_names(op), persistable_var_names, op_type) + # collect quanted op output var name + for out_var_name in _get_op_output_var_names(op): + for in_var_name in _get_op_input_var_names(op): + if in_var_name in persistable_var_names: + self._quantized_op_output_name_dict[ + in_var_name] = out_var_name # For other op, only sample output scale elif op_type in self._out_scale_op_list: collect_var_name( @@ -577,7 +599,7 @@ def _sampling(self): def _sample_mse(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": @@ -594,7 +616,7 @@ def _sample_mse(self): self._quantized_threshold[var_name] = abs_max_value _logger.info("MSE searching stage ...") for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value @@ -616,7 +638,7 @@ def _sample_mse(self): def _sample_avg(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": @@ -633,7 +655,7 @@ def 
_sample_avg(self): self._quantized_threshold[var_name] = abs_max_value for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) abs_max_value = float(np.max(np.abs(var_tensor))) if (var_name not in self._quantized_var_avg): self._quantized_var_avg[var_name] = [] @@ -645,7 +667,7 @@ def _sample_avg(self): def _sample_abs_max(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": @@ -662,7 +684,7 @@ def _sample_abs_max(self): self._quantized_threshold[var_name] = abs_max_value for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) abs_max_value = float(np.max(np.abs(var_tensor))) if (var_name not in self._quantized_threshold) or \ (abs_max_value > self._quantized_threshold[var_name]): @@ -671,7 +693,7 @@ def _sample_abs_max(self): def _sample_min_max(self): if self._quantized_var_min == {} and self._quantized_var_max == {}: for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) @@ -691,7 +713,7 @@ def _sample_min_max(self): self._quantized_var_max[var_name] = max_value for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) if (var_name not in self._quantized_var_min) or \ @@ -703,7 +725,7 @@ def _sample_min_max(self): def _sample_histogram(self): for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) var_tensor_abs = np.abs(var_tensor) bins = self._sampling_act_histogram[var_name][1] hist, _ = np.histogram(var_tensor_abs, bins=bins) @@ -733,7 +755,7 @@ def _collect_activation_abs_min_max(self): get the min and max value, and then calculate the threshold. 
''' for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) var_tensor = np.abs(var_tensor) min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) @@ -767,7 +789,7 @@ def _calculate_kl_hist_threshold(self): # Abs_max threshold for weights for var_name in self._quantized_weight_var_name: - weight_data = _load_variable_data(self._scope, var_name) + weight_data = load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": weight_threshold = float(np.max(np.abs(weight_data))) elif self._weight_quantize_type == "channel_wise_abs_max": @@ -842,13 +864,13 @@ def _update_program(self): else: scale_dict = self._quantized_threshold for key, val in scale_dict.items(): - _set_variable_data( + set_variable_data( self._scope, self._place, key + ".scale", np.array( [val], dtype=np.float32)) - _set_variable_data( + set_variable_data( self._scope, self._place, key + ".quant_dequant.scale", @@ -861,6 +883,7 @@ def _update_program(self): place=self._place, bias_correction=self._bias_correction, weight_bits=self._weight_bits, + round_type=self._round_type, activation_bits=self._activation_bits, weight_quantize_type=self._weight_quantize_type, quantizable_op_type=major_quantizable_op_types) @@ -951,7 +974,7 @@ def _collect_dynamic_quantize_op_threshold(self, target_ops_type): for op in target_ops: for var_name in _get_op_input_var_names(op): if var_name in persistable_var_names: - var_data = _load_variable_data(self._scope, var_name) + var_data = load_variable_data(self._scope, var_name) threshold = float(np.max(np.abs(var_data))) argname, index = _get_input_name_index(op, var_name) op._set_attr(argname + str(index) + "_threshold", threshold) @@ -1197,7 +1220,7 @@ def _weight_abs_max_quantization(self, scope, place, weight_bits, save_weight_dtype = np.int8 if weight_bits == 8 else np.int16 # Get quantized scale and weight data - weight_data = _load_variable_data(scope, var_name) + weight_data = load_variable_data(scope, var_name) if abs(threshold_rate) < 1e-10: threshold_value = np.max(np.abs(weight_data)) else: @@ -1211,11 +1234,11 @@ def _weight_abs_max_quantization(self, scope, place, weight_bits, # Set weight data if not for_test: - _set_variable_data(scope, place, var_name, quantized_weight_data) + set_variable_data(scope, place, var_name, quantized_weight_data) else: dequantized_weight_data = \ (quantized_weight_data * scale).astype(np.float32) - _set_variable_data(scope, place, var_name, dequantized_weight_data) + set_variable_data(scope, place, var_name, dequantized_weight_data) # Save info op._set_attr('quantization_type', 'post_weight_abs_max') @@ -1232,7 +1255,7 @@ def _weight_channel_wise_abs_max_quantization( save_weight_dtype = np.int8 if weight_bits == 8 else np.int16 # Get quantized scale and weight data - weight_data = _load_variable_data(scope, var_name) + weight_data = load_variable_data(scope, var_name) if op.type == "mul": scales, quantized_weight_data = \ self._mul_channel_wise_quantization(weight_data, @@ -1246,7 +1269,7 @@ def _weight_channel_wise_abs_max_quantization( # Set weight data if not for_test: - _set_variable_data(scope, place, var_name, quantized_weight_data) + set_variable_data(scope, place, var_name, quantized_weight_data) else: if op.type == "mul": dequantized_weight_data = \ @@ -1257,7 +1280,7 @@ def _weight_channel_wise_abs_max_quantization( else: _logger.error(op.type + " is not supported by weight quantization") - 
_set_variable_data(scope, place, var_name, dequantized_weight_data) + set_variable_data(scope, place, var_name, dequantized_weight_data) # Save info op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max') diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 645feda21f0f3..1abade7847dfa 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -26,6 +26,7 @@ from ....layers import mean from ....executor import scope_guard from ....framework import _get_paddle_place +from .utils import _channelwise_quant_axis1_ops, quant_tensor __all__ = [ 'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass', @@ -141,10 +142,6 @@ _conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] -_channelwise_quant_axis1_ops = [ - 'conv2d_transpose', 'mul', 'matmul', 'matmul_v2' -] - def _get_op_input_var_names(op): """ @@ -1114,6 +1111,7 @@ def __init__(self, bias_correction=False, weight_bits=8, activation_bits=8, + round_type='round', weight_quantize_type='abs_max', quantizable_op_type=None): """ @@ -1131,6 +1129,9 @@ def __init__(self, https://arxiv.org/abs/1810.05723. weight_bits(int): quantization bit number for weights. activation_bits(int): quantization bit number for activation. + round_type(str, optional): The method of converting the quantized weights + value from float to int. Currently supports ['round', 'adaround'] methods. + Default is `round`, which is rounding nearest to the nearest whole number. weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained. 
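[Editor's note: as a reference for the `round_type` option documented in the hunk above, here is a minimal NumPy sketch contrasting the two modes. The GAMMA/ZETA constants and the soft-rounding formula mirror `adaround.py` from this patch; the tensor values and the fixed `alpha` are purely illustrative — in the actual pass, `alpha` is trained per weight by `run_adaround`.]

    import numpy as np

    GAMMA, ZETA = -0.1, 1.1            # stretch constants from adaround.py
    bnt = (1 << (8 - 1)) - 1           # int8 bound, as in quant_tensor

    w = np.array([0.012, -0.034, 0.056], dtype=np.float32)  # toy weights
    s = float(np.max(np.abs(w)))       # abs_max scale

    # round_type='round': nearest rounding of the scaled tensor
    w_round = np.round(w / s * bnt)

    # round_type='adaround': floor plus a learned 0/1 offset h(alpha)
    alpha = np.zeros_like(w)           # illustrative; learned by run_adaround
    h = np.clip(1.0 / (1.0 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0, 1)
    w_ada = np.floor(w / s * bnt) + h  # the round loss drives h to exactly 0 or 1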
@@ -1146,6 +1147,7 @@ def __init__(self, self._place = _get_paddle_place(place) self._weight_bits = weight_bits self._activation_bits = activation_bits + self._round_type = round_type self._weight_quantize_type = weight_quantize_type self._fake_quant_op_names = _fake_quant_op_list self._fake_dequant_op_names = _fake_dequant_op_list @@ -1192,18 +1194,22 @@ def apply(self, graph): self._quant_var_scale_map[input_arg_name] = scale_v # Quantize weight and restore param_v = self._load_var(input_arg_name) - if isinstance(scale_v, list) and \ - any(_check_grandchild_op_node(op_node, op) - for op in _channelwise_quant_axis1_ops): - quant_axis = 1 - else: - quant_axis = 0 - quantized_param_v = self._quant( - param_v.copy(), scale_v, self._weight_bits, quant_axis) - if self._bias_correction == True: - quantized_param_v = self._bias_correction_w( - param_v, quantized_param_v, scale_v, quant_axis) - self._restore_var(input_arg_name, quantized_param_v) + if self._round_type == 'round': + if any( + _check_grandchild_op_node(op_node, op) + for op in _channelwise_quant_axis1_ops): + quant_axis = 1 + else: + quant_axis = 0 + quantized_param_v = quant_tensor(param_v.copy(), + scale_v, quant_axis, + self._weight_bits) + quantized_param_v = np.round(quantized_param_v) + if self._bias_correction == True: + quantized_param_v = self._bias_correction_w( + param_v, quantized_param_v, scale_v, quant_axis) + quantized_param_v = np.round(quantized_param_v) + self._restore_var(input_arg_name, quantized_param_v) self._remove_fake_quant_and_dequant_op(graph, op_node) # Remove all fake dequant op @@ -1220,6 +1226,12 @@ def apply(self, graph): if op_node_desc.has_attr("quantization_type") and \ op_node_desc.attr("quantization_type") == "qat_with_weight": if self._weight_quantize_type == 'channel_wise_abs_max': + if any( + _check_grandchild_op_node(op_node, op) + for op in _channelwise_quant_axis1_ops): + quant_axis = 1 + else: + quant_axis = 0 self._insert_post_channel_dequant_op(graph, op_node, quant_axis) else: @@ -1419,31 +1431,6 @@ def _is_float(self, v): return isinstance(v, float) or isinstance(v, np.float32) \ or isinstance(v, np.float64) - def _quant(self, x, scale, num_bits, quant_axis): - assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' - bnt = (1 << (num_bits - 1)) - 1 - - def _clip(x, scale): - x[x > scale] = scale - x[x < -scale] = -scale - return x - - if isinstance(scale, list): - for i, s in enumerate(scale): - if s == 0.0: - s = 1e-8 - if quant_axis == 0: - x[i] = _clip(x[i], s) - x[i] = np.round(x[i] / s * bnt) - else: - x[:, i] = _clip(x[:, i], s) - x[:, i] = np.round(x[:, i] / s * bnt) - else: - scale = 1e-8 if scale == 0.0 else scale - x = _clip(x, scale) - x = np.round(x / scale * bnt) - return x - def _bias_correction_w(self, x, x_quant, scale_v, quant_axis): ''' Bias correction for weight @@ -1480,8 +1467,8 @@ def _bias_correction_w(self, x, x_quant, scale_v, quant_axis): mean_bias = np.resize(mean_bias, x.shape) x_dequant = (mean_bias + x_dequant) * std_bias - quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits, - quant_axis) + quantized_param_v = quant_tensor(x_dequant, scale_v, quant_axis, + self._weight_bits) return quantized_param_v diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py new file mode 100644 index 0000000000000..43f33f33c3138 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +_channelwise_quant_axis1_ops = [ + 'conv2d_transpose', 'mul', 'matmul', 'matmul_v2' +] + + +def load_variable_data(scope, var_name): + ''' + Load variable value from scope + ''' + var_node = scope.find_var(var_name) + assert var_node is not None, \ + "Cannot find " + var_name + " in scope." + return np.array(var_node.get_tensor()) + + +def set_variable_data(scope, place, var_name, np_value): + ''' + Set the value of var node by name, if the node exits, + ''' + assert isinstance(np_value, np.ndarray), \ + 'The type of value should be numpy array.' + var_node = scope.find_var(var_name) + if var_node != None: + tensor = var_node.get_tensor() + tensor.set(np_value, place) + + +def quant_tensor(x, scale, quant_axis=0, weight_bits=8): + # symmetry quant + def _clip(x, scale): + x[x > scale] = scale + x[x < -scale] = -scale + return x + + assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' + bnt = (1 << (weight_bits - 1)) - 1 + if isinstance(scale, list): + for i, s in enumerate(scale): + if s == 0.0: + s = 1e-8 + if quant_axis == 0: + x[i] = _clip(x[i], s) + x[i] = x[i] / s * bnt + else: + x[:, i] = _clip(x[:, i], s) + x[:, i] = x[:, i] / s * bnt + else: + scale = 1e-8 if scale == 0.0 else scale + x = _clip(x, scale) + x = x / scale * bnt + return x + + +def dequant_tensor(x, scale, quant_axis=0, weight_bits=8): + assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' 
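+    # dequant_tensor inverts quant_tensor above: quantized values in
+    # [-bnt, bnt] are mapped back to floats as x * scale / bnt, applied
+    # per output channel (along quant_axis) when `scale` is a list.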
+ bnt = (1 << (weight_bits - 1)) - 1 + if isinstance(scale, list): + for i, s in enumerate(scale): + if s == 0.0: + s = 1e-8 + if quant_axis == 0: + x[i] = x[i] * s / bnt + else: + x[:, i] = x[:, i] * s / bnt + else: + scale = 1e-8 if scale == 0.0 else scale + x = x * scale / bnt + return x + + +def stable_sigmoid(x): + sig = np.where(x < 0, np.exp(x) / (1 + np.exp(x)), 1 / (1 + np.exp(-x))) + return sig + + +def calculate_quant_cos_error(orig_tensor, qdq_tensor): + cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \ + / (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten())) + return cos_sim diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 7161104861006..7790fea857590 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -240,6 +240,7 @@ def generate_quantized_model(self, model_path, quantizable_op_type, algo="KL", + round_type="round", is_full_quantize=False, is_use_cache_file=False, is_optimize_model=False): @@ -261,15 +262,16 @@ def generate_quantized_model(self, model_dir=model_path, algo=algo, quantizable_op_type=quantizable_op_type, + round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model) - def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold): + def run_test(self, model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold): infer_iterations = self.infer_iterations batch_size = self.batch_size sample_iterations = self.sample_iterations @@ -285,7 +287,7 @@ def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type, format(model, sample_iterations * batch_size)) self.generate_quantized_model( model_cache_folder + "/model", quantizable_op_type, algo, - is_full_quantize, is_use_cache_file, is_optimize_model) + round_type, is_full_quantize, is_use_cache_file, is_optimize_model) print("Start INT8 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) @@ -309,6 +311,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): def test_post_training_kl_mobilenetv1(self): model = "MobileNet-V1" algo = "KL" + round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -323,15 +326,16 @@ def test_post_training_kl_mobilenetv1(self): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.025 - self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold) + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): def test_post_training_avg_mobilenetv1(self): model = "MobileNet-V1" algo = "avg" + round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -345,15 +349,16 @@ def test_post_training_avg_mobilenetv1(self): is_use_cache_file 
= False is_optimize_model = True diff_threshold = 0.025 - self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold) + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): def test_post_training_hist_mobilenetv1(self): model = "MobileNet-V1" algo = "hist" + round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -367,15 +372,16 @@ def test_post_training_hist_mobilenetv1(self): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.025 - self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold) + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): def test_post_training_abs_max_mobilenetv1(self): model = "MobileNet-V1" algo = "abs_max" + round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -389,9 +395,32 @@ def test_post_training_abs_max_mobilenetv1(self): is_optimize_model = False # The accuracy diff of post-traing quantization (abs_max) maybe bigger diff_threshold = 0.05 - self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold) + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) + + +class TestPostTrainingAdaRoundForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_adaround_mobilenetv1(self): + model = "MobileNet-V1" + algo = "avg" + round_type = "adaround" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.025 + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) if __name__ == '__main__': From 7e4fea843a095a8805d7a85fc96c636b7369a49f Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Wed, 2 Mar 2022 04:34:27 +0000 Subject: [PATCH 2/4] add unit test --- ...t_post_training_quantization_lstm_model.py | 44 ++++++++--- .../test_post_training_quantization_mnist.py | 76 +++++++++++++------ ..._post_training_quantization_mobilenetv1.py | 73 +++++++++++++++++- ...est_post_training_quantization_resnet50.py | 7 +- 4 files changed, 164 insertions(+), 36 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py index 8a28ee7983e6a..58a430eb96406 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py @@ -167,6 +167,7 @@ def generate_quantized_model(self, 
model_path, data_path, algo="KL", + round_type="round", quantizable_op_type=["conv2d"], is_full_quantize=False, is_use_cache_file=False, @@ -186,6 +187,7 @@ def generate_quantized_model(self, batch_nums=batch_nums, algo=algo, quantizable_op_type=quantizable_op_type, + round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, is_use_cache_file=is_use_cache_file) @@ -193,9 +195,9 @@ def generate_quantized_model(self, ptq.save_quantized_model(self.int8_model_path) def run_test(self, model_name, model_url, model_md5, data_name, data_url, - data_md5, algo, quantizable_op_type, is_full_quantize, - is_use_cache_file, is_optimize_model, diff_threshold, - infer_iterations, quant_iterations): + data_md5, algo, round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, infer_iterations, quant_iterations): fp32_model_path = self.download_model(model_url, model_md5, model_name) fp32_model_path = os.path.join(fp32_model_path, model_name) @@ -210,9 +212,9 @@ def run_test(self, model_name, model_url, model_md5, data_name, data_url, print("Start post training quantization for {0} on {1} samples ...". format(model_name, quant_iterations)) self.generate_quantized_model(fp32_model_path, data_path, algo, - quantizable_op_type, is_full_quantize, - is_use_cache_file, is_optimize_model, - quant_iterations) + round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, + is_optimize_model, quant_iterations) print("Start INT8 inference for {0} on {1} samples ...".format( model_name, infer_iterations)) @@ -239,6 +241,7 @@ def test_post_training_kl(self): data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" data_md5 = "add84c754e9b792fea1fbd728d134ab7" algo = "KL" + round_type = "round" quantizable_op_type = ["mul", "lstm"] is_full_quantize = False is_use_cache_file = False @@ -247,9 +250,32 @@ def test_post_training_kl(self): infer_iterations = 100 quant_iterations = 10 self.run_test(model_name, model_url, model_md5, data_name, data_url, - data_md5, algo, quantizable_op_type, is_full_quantize, - is_use_cache_file, is_optimize_model, diff_threshold, - infer_iterations, quant_iterations) + data_md5, algo, round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, infer_iterations, quant_iterations) + + +class TestPostTrainingKLForMnistAdaround(TestPostTrainingQuantization): + def test_post_training_kl(self): + model_name = "nlp_lstm_fp32_model" + model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" + model_md5 = "519b8eeac756e7b4b7bcb2868e880452" + data_name = "quant_lstm_input_data" + data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" + data_md5 = "add84c754e9b792fea1fbd728d134ab7" + algo = "KL" + round_type = "adaround" + quantizable_op_type = ["mul", "lstm"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + diff_threshold = 0.01 + infer_iterations = 100 + quant_iterations = 10 + self.run_test(model_name, model_url, model_md5, data_name, data_url, + data_md5, algo, round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + diff_threshold, infer_iterations, quant_iterations) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py 
b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py index da5c5d6dc9441..1ed0dcb859f05 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py @@ -110,6 +110,7 @@ def run_program(self, model_path, batch_size, infer_iterations): def generate_quantized_model(self, model_path, algo="KL", + round_type="round", quantizable_op_type=["conv2d"], is_full_quantize=False, is_use_cache_file=False, @@ -130,6 +131,7 @@ def generate_quantized_model(self, batch_nums=batch_nums, algo=algo, quantizable_op_type=quantizable_op_type, + round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, is_use_cache_file=is_use_cache_file) @@ -141,6 +143,7 @@ def run_test(self, data_url, data_md5, algo, + round_type, quantizable_op_type, is_full_quantize, is_use_cache_file, @@ -160,9 +163,10 @@ def run_test(self, print("Start INT8 post training quantization for {0} on {1} images ...". format(model_name, quant_iterations * batch_size)) - self.generate_quantized_model( - origin_model_path, algo, quantizable_op_type, is_full_quantize, - is_use_cache_file, is_optimize_model, batch_size, quant_iterations) + self.generate_quantized_model(origin_model_path, algo, round_type, + quantizable_op_type, is_full_quantize, + is_use_cache_file, is_optimize_model, + batch_size, quant_iterations) print("Start INT8 inference for {0} on {1} images ...".format( model_name, infer_iterations * batch_size)) @@ -190,6 +194,7 @@ def test_post_training_kl(self): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "KL" + round_type = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -198,10 +203,10 @@ def test_post_training_kl(self): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, batch_size, infer_iterations, - quant_iterations) + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) class TestPostTraininghistForMnist(TestPostTrainingQuantization): @@ -210,6 +215,7 @@ def test_post_training_hist(self): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "hist" + round_type = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -218,10 +224,10 @@ def test_post_training_hist(self): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, batch_size, infer_iterations, - quant_iterations) + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) class TestPostTrainingmseForMnist(TestPostTrainingQuantization): @@ -230,6 +236,7 @@ def test_post_training_mse(self): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = 
"be71d3997ec35ac2a65ae8a145e2887c" algo = "mse" + round_type = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -238,10 +245,10 @@ def test_post_training_mse(self): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, batch_size, infer_iterations, - quant_iterations) + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) class TestPostTrainingavgForMnist(TestPostTrainingQuantization): @@ -250,6 +257,7 @@ def test_post_training_avg(self): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "avg" + round_type = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -258,10 +266,10 @@ def test_post_training_avg(self): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, batch_size, infer_iterations, - quant_iterations) + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): @@ -270,6 +278,7 @@ def test_post_training_abs_max(self): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "abs_max" + round_type = "round" quantizable_op_type = ["conv2d", "mul"] is_full_quantize = True is_use_cache_file = False @@ -278,10 +287,31 @@ def test_post_training_abs_max(self): batch_size = 10 infer_iterations = 50 quant_iterations = 10 - self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, batch_size, infer_iterations, - quant_iterations) + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) + + +class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): + def test_post_training_mse(self): + model_name = "mnist_model" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" + data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" + algo = "mse" + round_type = "adaround" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 
7790fea857590..8ed83894689b3 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -400,7 +400,7 @@ def test_post_training_abs_max_mobilenetv1(self): is_optimize_model, diff_threshold) -class TestPostTrainingAdaRoundForMobilenetv1(TestPostTrainingQuantization): +class TestPostTrainingAvgAdaRoundForMobilenetv1(TestPostTrainingQuantization): def test_post_training_adaround_mobilenetv1(self): model = "MobileNet-V1" algo = "avg" @@ -423,5 +423,76 @@ def test_post_training_adaround_mobilenetv1(self): is_optimize_model, diff_threshold) +class TestPostTrainingAbsMaxAdaRoundForMobilenetv1( + TestPostTrainingQuantization): + def test_post_training_adaround_mobilenetv1(self): + model = "MobileNet-V1" + algo = "abs_max" + round_type = "adaround" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.025 + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) + + +class TestPostTraininghistAdaroundForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_hist_mobilenetv1(self): + model = "MobileNet-V1" + algo = "hist" + round_type = "adaround" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.025 + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) + + +class TestPostTrainingKLAdaroundForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_kl_mobilenetv1(self): + model = "MobileNet-V1" + algo = "KL" + round_type = "adaround" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + "pool2d", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + diff_threshold = 0.025 + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py index 12b5a2458a4da..a26dcb51c724a 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py @@ -24,6 +24,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): def test_post_training_resnet50(self): model = "ResNet-50" algo = "min_max" + round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' ] @@ -33,9 +34,9 @@ def test_post_training_resnet50(self): 
is_use_cache_file = False is_optimize_model = False diff_threshold = 0.025 - self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold) + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) if __name__ == '__main__': From 08572030baa52ebe56991963ef73f05d650d24a1 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Thu, 24 Mar 2022 08:29:33 +0000 Subject: [PATCH 3/4] fix some parameter --- .../contrib/slim/quantization/adaround.py | 7 +++---- .../quantization/post_training_quantization.py | 18 +++++++----------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/adaround.py b/python/paddle/fluid/contrib/slim/quantization/adaround.py index fa1f45ebeb6c0..f6908d7e836a7 100644 --- a/python/paddle/fluid/contrib/slim/quantization/adaround.py +++ b/python/paddle/fluid/contrib/slim/quantization/adaround.py @@ -196,7 +196,7 @@ def run_adaround(data_loader, exe, scope, place, - quantized_op_output_name_dict, + quantized_op_pairs, weight_op_pairs, scale_dict, num_iterations=1000, @@ -204,8 +204,7 @@ def run_adaround(data_loader, fast_mode=True): fetch_op_name = fetch_list[0].name final_weight_tensor_quant_dict = {} - for weight_var_name, quant_op_out_name in quantized_op_output_name_dict.items( - ): + for weight_var_name, quant_op_out_name in quantized_op_pairs.items(): _logger.info('Start adaround op: {}'.format(weight_var_name)) weight_op_type = weight_op_pairs[weight_var_name] # get scale and weight tensor @@ -305,6 +304,6 @@ def run_adaround(data_loader, del adaround # update adarounded calibrated weights - for weight_var_name in quantized_op_output_name_dict.keys(): + for weight_var_name in quantized_op_pairs.keys(): set_variable_data(scope, place, weight_var_name, final_weight_tensor_quant_dict[weight_var_name]) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index e5b4903312c05..65e13c5127773 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -124,7 +124,6 @@ def __init__(self, hist_percent=0.99999, quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], round_type='round', - train_iterations=1000, learning_rate=0.001, is_full_quantize=False, bias_correction=False, @@ -182,11 +181,9 @@ def __init__(self, that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. round_type(str, optional): The method of converting the quantized weights - value from float to int. Currently supports ['round', 'adaround'] methods. + value float->int. Currently supports ['round', 'adaround'] methods. Default is `round`, which is rounding nearest to the nearest whole number. - train_iterations(flota, optional): The number of training iter, used to - calibrate the adaptive rounding method, when round_type='adaround'. - learning_rate(flota, optional): The learning rate of adaround method. + learning_rate(float, optional): The learning rate of adaround method. is_full_quantized(bool, optional): If set is_full_quantized as True, apply quantization to all supported quantizable op type. 
If set is_full_quantized as False, only apply quantization to the op type @@ -265,7 +262,6 @@ def __init__(self, ] assert round_type in ['adaround', 'round'] self._round_type = round_type - self._train_iterations = train_iterations self._learning_rate = learning_rate self._dynamic_quantize_op_type = ['lstm'] self._support_quantize_op_type = \ @@ -446,10 +442,10 @@ def _adaround_apply(self): self._executor, self._scope, self._place, - self._quantized_op_output_name_dict, + self._quantized_op_pairs, self._weight_op_pairs, scale_dict, - num_iterations=self._train_iterations, + num_iterations=self._batch_nums, lr=self._learning_rate) def save_quantized_model(self, @@ -534,7 +530,7 @@ def _collect_target_varnames(self): ''' # TODO(juncaipeng), consider the name_scope of skip_quant _logger.info("Collect quantized variable names ...") - self._quantized_op_output_name_dict = {} + self._quantized_op_pairs = {} def collect_var_name(var_name_list, persistable_var_names, op_type): for var_name in var_name_list: @@ -564,7 +560,7 @@ def collect_var_name(var_name_list, persistable_var_names, op_type): for out_var_name in _get_op_output_var_names(op): for in_var_name in _get_op_input_var_names(op): if in_var_name in persistable_var_names: - self._quantized_op_output_name_dict[ + self._quantized_op_pairs[ in_var_name] = out_var_name # For other op, only sample output scale elif op_type in self._out_scale_op_list: @@ -984,7 +980,7 @@ def analysis_and_save_info(op_node, out_var_name): argname_index[0] + str(argname_index[1]) + "_threshold", "post_hist") - elif self._algo in ["avg", "abs_max", "mse"]: + elif self._algo in ["avg", "abs_max", "mse", "emd"]: save_info(op_node, out_var_name, self._quantized_threshold, "out_threshold", "post_" + str(self._algo)) save_info( From d25666807995d410090840935c37c446a70b1d65 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Fri, 25 Mar 2022 02:10:13 +0000 Subject: [PATCH 4/4] fix unittest --- .../slim/quantization/post_training_quantization.py | 4 ++-- .../slim/tests/test_post_training_quantization_mnist.py | 9 +++++---- .../tests/test_post_training_quantization_mobilenetv1.py | 7 ++++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 65e13c5127773..b1b645e85e75d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -647,7 +647,7 @@ def _sample_mse(self): def _sample_emd(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": @@ -664,7 +664,7 @@ def _sample_emd(self): self._quantized_threshold[var_name] = abs_max_value _logger.info("EMD searching stage ...") for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) + var_tensor = load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py 
b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py index c814ef539111d..74198da11fb2c 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py @@ -257,6 +257,7 @@ def test_post_training_mse(self): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "emd" + round_type = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -265,10 +266,10 @@ def test_post_training_mse(self): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, batch_size, infer_iterations, - quant_iterations) + self.run_test(model_name, data_url, data_md5, algo, round_type, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold, batch_size, + infer_iterations, quant_iterations) class TestPostTrainingavgForMnist(TestPostTrainingQuantization): diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 038c05f5e6112..312a0c9e4b40e 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -498,6 +498,7 @@ class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization): def test_post_training_avg_mobilenetv1(self): model = "MobileNet-V1" algo = "emd" + round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -511,9 +512,9 @@ def test_post_training_avg_mobilenetv1(self): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.025 - self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold) + self.run_test(model, algo, round_type, data_urls, data_md5s, + quantizable_op_type, is_full_quantize, is_use_cache_file, + is_optimize_model, diff_threshold) if __name__ == '__main__':
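[Editor's note: taken together, the series exposes AdaRound through a single `round_type` switch on PostTrainingQuantization. A minimal end-to-end usage sketch, assuming the post-PATCH-3/4 API; the model path and the calibration `my_sample_generator` below are placeholders, not part of this series.]

    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    ptq = PostTrainingQuantization(
        executor=exe,
        sample_generator=my_sample_generator,  # placeholder calibration reader
        model_dir="./fp32_model",              # placeholder model directory
        batch_size=10,
        batch_nums=10,          # after PATCH 3, this also bounds the AdaRound iterations
        algo="avg",
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        round_type="adaround",  # 'round' keeps the original nearest rounding
        learning_rate=0.001)    # AdaRound optimizer learning rate
    ptq.quantize()
    ptq.save_quantized_model("./int8_model")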