add new function ptq first then initialize qat scale with ptq scale (#…
ZhangHandi authored Aug 25, 2022
1 parent bdd0b0f commit 9ac27ac
Showing 8 changed files with 560 additions and 115 deletions.
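
This commit wires post-training quantization (PTQ) into the QAT pipeline: PTQ runs first, and the calibrated scales it produces initialize the (now trainable) QAT scales. Below is a minimal, hedged sketch of how the new PostTrainingQuantizationProgram entry point might be driven, assuming it is imported from the module this file provides (the updated __all__ exports it); exe, calib_program, feed_vars, fetch_vars and calib_loader are placeholder names, not part of this commit.

from paddle.fluid.contrib.slim.quantization import \
    PostTrainingQuantizationProgram

# Hypothetical setup: exe (paddle.static.Executor), calib_program (fp32
# Program), feed_vars/fetch_vars and calib_loader are built elsewhere.
ptq = PostTrainingQuantizationProgram(
    exe,
    calib_program,
    feed_list=feed_vars,        # required by this subclass
    fetch_list=fetch_vars,      # required by this subclass
    data_loader=calib_loader,   # calibration batches
    algo='KL',
    freeze_model=False,         # keep fake quant/dequant ops for QAT
    scale_trainable=True,       # insert quant passes with is_test=False
    return_graph=True)          # get an IrGraph back instead of a Program
calib_graph = ptq.quantize()
# calib_graph now carries PTQ-calibrated scales that QAT can fine-tune.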
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import os
import re
import math
import shutil
import logging
import numpy as np
import shutil

try:
from tqdm import tqdm
except:
@@ -34,7 +36,10 @@
from .adaround import run_adaround
from . import utils

__all__ = ['PostTrainingQuantization', 'WeightQuantization']
__all__ = [
'PostTrainingQuantization', 'WeightQuantization',
'PostTrainingQuantizationProgram'
]

_logger = get_logger(__name__,
logging.INFO,
@@ -108,9 +113,9 @@ class PostTrainingQuantization(object):
"""

def __init__(self,
executor=None,
executor,
model_dir,
scope=None,
model_dir=None,
model_filename=None,
params_filename=None,
batch_generator=None,
@@ -130,10 +135,15 @@ def __init__(self,
activation_quantize_type='range_abs_max',
weight_quantize_type='channel_wise_abs_max',
onnx_format=False,
freeze_model=True,
optimize_model=False,
is_use_cache_file=False,
skip_tensor_list=None,
cache_dir=None):
same_scale_tensor_list=None,
scale_trainable=False,
cache_dir=None,
scale_dict=None,
return_graph=False):
'''
Constructor.
@@ -206,7 +216,12 @@ def __init__(self,
the model accuracy is usually higher when it is 'channel_wise_abs_max'.
onnx_format(bool): Whether to export the quantized model with format of ONNX.
Default is False.
skip_tensor_list(list): Names of the tensors to skip quantizing.
freeze_model(bool): Whether to convert the quantized, trained ``program`` to the
final quantized ``program``. Default: True.
skip_tensor_list(list): Names of the tensors to skip quantizing. Default: None.
same_scale_tensor_list(list(list)): Groups of tensors that must share one scale,
given as a list of lists; for each inner list, the final scale is the maximum
of the scales of its tensors. Default: None.
optimize_model(bool, optional): If optimize_model is set to True, some passes
are applied to the model before quantization; so far only the
`conv2d/depthwise_conv2d + bn` pass is supported. Some targets require the
@@ -215,6 +230,7 @@
`conv2d/depthwise_conv2d + bn`, the weight scales for each channel will
differ. To address this problem, the pattern is fused before
quantization. Default: False.
scale_trainable(bool, optional): Whether the quantization scales are trainable.
Default: False.
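scale_dict(dict, optional): A pre-computed mapping from tensor name to scale.
If given, these scales are written into the program instead of the ones
derived from calibration. Default: None.
return_graph(bool, optional): If True, ``quantize()`` returns the quantized
IrGraph instead of the quantized Program. Default: False.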
is_use_cache_file(bool, optional): This param is deprecated.
cache_dir(str, optional): This param is deprecated.
Returns:
@@ -275,7 +291,6 @@ def __init__(self,

# Check inputs
assert executor is not None, "The executor cannot be None."
assert model_dir is not None, "The model_dir cannot be None."
assert any(gen is not None for gen in [sample_generator,
batch_generator, data_loader]), "The sample_generator, batch_generator " \
"and data_loader cannot all be None at the same time."
@@ -347,6 +362,11 @@ def __init__(self,
self._best_calibration_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}
self._same_scale_tensor_list = same_scale_tensor_list
self._freeze_model = freeze_model
self._scale_trainable = scale_trainable
self._scale_dict = scale_dict
self._return_graph = return_graph

def quantize(self):
'''
@@ -441,7 +461,11 @@ def quantize(self):
persistables.extend(_op.input('X'))
_op.desc.set_input("X", persistables)

return self._program
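# With return_graph=True, wrap the program in an IrGraph so a follow-up
# QAT pass can consume the calibrated graph (and its scales) directly.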
if not self._return_graph:
return self._program
else:
main_graph = IrGraph(core.Graph(self._program.desc), for_test=True)
return main_graph

def _adaround_apply(self):
assert self._algo != "min_max", "The algo should not be min_max."
@@ -495,12 +519,13 @@ def _load_model_data(self):
'''
Load model and set data loader.
'''
_logger.info("Load model and set data loader ...")
[self._program, self._feed_list, self._fetch_list] = \
io.load_inference_model(dirname=self._model_dir,
executor=self._executor,
model_filename=self._model_filename,
params_filename=self._params_filename)
if self._program is None:
_logger.info("Load model and set data loader ...")
[self._program, self._feed_list, self._fetch_list] = \
io.load_inference_model(dirname=self._model_dir,
executor=self._executor,
model_filename=self._model_filename,
params_filename=self._params_filename)

if self._optimize_model:
self._optimize_fp32_model()
@@ -972,7 +997,8 @@ def _update_program(self):
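# The transform passes below now receive is_test=not scale_trainable:
# with trainable scales the graph is built in training mode, so the PTQ
# scales written later stay updatable during QAT.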
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
quantizable_op_type=major_quantizable_op_types,
is_test=not self._scale_trainable)
else:
transform_pass = QuantizationTransformPassV2(
scope=self._scope,
@@ -981,7 +1007,8 @@
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
quantizable_op_type=major_quantizable_op_types,
is_test=not self._scale_trainable)

for sub_graph in graph.all_sub_graphs():
# Insert fake_quant/fake_dequantize op must in test graph, so
@@ -998,24 +1025,68 @@ def _update_program(self):
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types)
quantizable_op_type=minor_quantizable_op_types,
is_test=not self._scale_trainable)
else:
add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=self._is_full_quantize)
is_full_quantized=self._is_full_quantize,
is_test=not self._scale_trainable)

for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
add_quant_dequant_pass.apply(sub_graph)

# save threshold to scale var node
if self._algo in ["KL", "hist"]:
scale_dict = self._quantized_var_threshold
else:
scale_dict = self._quantized_threshold
for key, val in scale_dict.items():
if self._scale_dict is None:
if self._algo in ["KL", "hist"]:
scale_dict = self._quantized_var_threshold
else:
scale_dict = self._quantized_threshold

if self._same_scale_tensor_list is not None:
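# An entry is either a plain tensor name or 'name#op#scalar' with op in
# {'*', '/'}: the scalar first rescales that tensor's scale, the group
# maximum is taken, and the maximum is then mapped back through the
# inverse operation below.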
for tensor_list in self._same_scale_tensor_list:
max_scale = None
tmp_tensor_list = []
for tensor_name in tensor_list:
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
'#')
if opera == '*':
scale_dict[real_tensor_name] = float(
scale_dict[real_tensor_name]) * float(
scalar)
elif opera == '/':
scale_dict[real_tensor_name] = float(
scale_dict[real_tensor_name]) / float(
scalar)
max_scale = scale_dict[
real_tensor_name] if max_scale is None else max(
max_scale, scale_dict[real_tensor_name])
else:
max_scale = scale_dict[
tensor_name] if max_scale is None else max(
max_scale, scale_dict[tensor_name])

for tensor_name in tensor_list:
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
'#')
if opera == '*':
scale_dict[
real_tensor_name] = max_scale / float(
scalar)
elif opera == '/':
scale_dict[
real_tensor_name] = max_scale * float(
scalar)
else:
scale_dict[tensor_name] = max_scale
self._scale_dict = scale_dict

for key, val in self._scale_dict.items():
utils.set_variable_data(self._scope, self._place, key + "@scale",
np.array([val], dtype=np.float32))
utils.set_variable_data(self._scope, self._place,
@@ -1024,19 +1095,20 @@ def _update_program(self):

if not self._onnx_format:
# apply QuantizationFreezePass, and obtain the final quant model
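# When freeze_model is False, freezing is skipped so the fake
# quant/dequant ops (now initialized with PTQ scales) remain in the
# program for QAT to train.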
freeze_pass = QuantizationFreezePass(
scope=self._scope,
place=self._place,
bias_correction=self._bias_correction,
weight_bits=self._weight_bits,
round_type=self._round_type,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)

for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
freeze_pass.apply(sub_graph)
if self._freeze_model:
freeze_pass = QuantizationFreezePass(
scope=self._scope,
place=self._place,
bias_correction=self._bias_correction,
weight_bits=self._weight_bits,
round_type=self._round_type,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)

for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
freeze_pass.apply(sub_graph)
else:
quant_weight_pass = QuantWeightPass(self._scope, self._place)
for sub_graph in graph.all_sub_graphs():
Expand Down Expand Up @@ -1155,6 +1227,58 @@ def _get_hist_scaling_factor(self, hist, hist_edges):
return (hist_index - 0.5) * bin_width


class PostTrainingQuantizationProgram(PostTrainingQuantization):

def __init__(self,
executor,
program,
feed_list=None,
fetch_list=None,
scope=None,
batch_generator=None,
sample_generator=None,
data_loader=None,
batch_size=10,
batch_nums=None,
algo="KL",
hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
round_type='round',
learning_rate=0.001,
is_full_quantize=False,
bias_correction=False,
activation_bits=8,
weight_bits=8,
activation_quantize_type='range_abs_max',
weight_quantize_type='channel_wise_abs_max',
onnx_format=False,
freeze_model=True,
optimize_model=False,
is_use_cache_file=False,
skip_tensor_list=None,
same_scale_tensor_list=None,
scale_trainable=False,
cache_dir=None,
scale_dict=None,
return_graph=True):
super().__init__(executor, None, scope, None, None, batch_generator,
sample_generator, data_loader, batch_size, batch_nums,
algo, hist_percent, quantizable_op_type, round_type,
learning_rate, is_full_quantize, bias_correction,
activation_bits, weight_bits, activation_quantize_type,
weight_quantize_type, onnx_format, freeze_model,
optimize_model, is_use_cache_file, skip_tensor_list,
same_scale_tensor_list, scale_trainable, cache_dir,
scale_dict, return_graph)
self._program = program
assert feed_list is not None, \
"Feed list should not be None."
assert fetch_list is not None, \
"Fetch list should not be None."
self._feed_list = feed_list
self._fetch_list = fetch_list
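
For illustration, a hedged example of the same_scale_tensor_list convention handled in _update_program above; the tensor names are hypothetical:

# One group: all three tensors end up sharing one effective scale.
# 'split_out#*#0.5' contributes scale * 0.5 to the group maximum and is
# assigned max_scale / 0.5 afterwards, so its rescaled value matches the
# shared maximum; '#/#' works the same way with division.
same_scale_tensor_list = [
    ['concat_out', 'split_out#*#0.5', 'elementwise_add_out#/#2.0'],
]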


class WeightQuantization(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
_supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max']
