Add silent option to quantization script (#17094)
Browse files Browse the repository at this point in the history
* Add silent option to quantization script

* Refactor code

* Fix lint
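
The whole change follows a single pattern: the script's logger becomes `None` when `--quiet` is passed, and every log call is guarded with `if logger:`. A minimal sketch of that pattern, distilled from the diff below (the sample message is illustrative):

    import argparse
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='suppress most of log')
    args = parser.parse_args()

    # With --quiet the logger stays None, so every guarded call is skipped.
    logger = None
    if not args.quiet:
        logging.basicConfig()
        logger = logging.getLogger('logger')
        logger.setLevel(logging.INFO)

    if logger:
        logger.info('this line is suppressed under --quiet')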
ZhennanQin authored and zhreshold committed Dec 18, 2019
1 parent ed09547 commit a18250d
Showing 3 changed files with 117 additions and 69 deletions.
example/quantization/imagenet_gen_qsym_mkldnn.py: 44 additions, 26 deletions
@@ -140,31 +140,40 @@ def save_params(fname, arg_params, aux_params, logger=None):
                         help='If enabled, the quantize op will '
                              'be calibrated offline if calibration mode is '
                              'enabled')
+    parser.add_argument('--quiet', action='store_true', default=False,
+                        help='suppress most of log')
     args = parser.parse_args()
     ctx = mx.cpu(0)
-    logging.basicConfig()
-    logger = logging.getLogger('logger')
-    logger.setLevel(logging.INFO)
+    logger = None
+    if not args.quiet:
+        logging.basicConfig()
+        logger = logging.getLogger('logger')
+        logger.setLevel(logging.INFO)

-    logger.info(args)
-    logger.info('shuffle_dataset=%s' % args.shuffle_dataset)
+    if logger:
+        logger.info(args)
+        logger.info('shuffle_dataset=%s' % args.shuffle_dataset)

     calib_mode = args.calib_mode
-    logger.info('calibration mode set to %s' % calib_mode)
+    if logger:
+        logger.info('calibration mode set to %s' % calib_mode)

     # download calibration dataset
     if calib_mode != 'none':
         download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset)

     # download model
     if not args.no_pretrained:
-        logger.info('Get pre-trained model from MXNet or Gluoncv modelzoo.')
-        logger.info('If you want to use custom model, please set --no-pretrained.')
+        if logger:
+            logger.info('Get pre-trained model from MXNet or Gluoncv modelzoo.')
+            logger.info('If you want to use custom model, please set --no-pretrained.')
         if args.model in ['imagenet1k-resnet-152', 'imagenet1k-inception-bn']:
-            logger.info('model %s is downloaded from MXNet modelzoo' % args.model)
+            if logger:
+                logger.info('model %s is downloaded from MXNet modelzoo' % args.model)
             prefix, epoch = download_model(model_name=args.model, logger=logger)
         else:
-            logger.info('model %s is converted from GluonCV' % args.model)
+            if logger:
+                logger.info('model %s is converted from GluonCV' % args.model)
             prefix = convert_from_gluon(model_name=args.model, image_shape=args.image_shape, classes=1000, logger=logger)
     rgb_mean = '123.68,116.779,103.939'
     rgb_std = '58.393, 57.12, 57.375'
@@ -178,14 +187,16 @@ def save_params(fname, arg_params, aux_params, logger=None):

     # get batch size
     batch_size = args.batch_size
-    logger.info('batch size = %d for calibration' % batch_size)
+    if logger:
+        logger.info('batch size = %d for calibration' % batch_size)

     # get number of batches for calibration
     num_calib_batches = args.num_calib_batches
-    if calib_mode == 'none':
-        logger.info('skip calibration step as calib_mode is none')
-    else:
-        logger.info('number of batches = %d for calibration' % num_calib_batches)
+    if logger:
+        if calib_mode == 'none':
+            logger.info('skip calibration step as calib_mode is none')
+        else:
+            logger.info('number of batches = %d for calibration' % num_calib_batches)

     # get number of threads for decoding the dataset
     data_nthreads = args.data_nthreads
@@ -195,7 +206,8 @@ def save_params(fname, arg_params, aux_params, logger=None):

     exclude_first_conv = args.exclude_first_conv
     if args.quantized_dtype == "uint8":
-        logger.info('quantized dtype is set to uint8, will exclude first conv.')
+        if logger:
+            logger.info('quantized dtype is set to uint8, will exclude first conv.')
         exclude_first_conv = True
     excluded_sym_names = []
     if not args.no_pretrained:
@@ -242,42 +254,48 @@ def save_params(fname, arg_params, aux_params, logger=None):
         else:
             raise ValueError('Currently, model %s is not supported in this script' % args.model)
     else:
-        logger.info('Please set proper RGB configs for model %s' % args.model)
+        if logger:
+            logger.info('Please set proper RGB configs for model %s' % args.model)
         # add rgb mean/std of your model.
         rgb_mean = '0,0,0'
         rgb_std = '0,0,0'
         # add layer names you donnot want to quantize.
-        logger.info('Please set proper excluded_sym_names for model %s' % args.model)
+        if logger:
+            logger.info('Please set proper excluded_sym_names for model %s' % args.model)
         excluded_sym_names += ['layers']
         if exclude_first_conv:
             excluded_sym_names += ['layers']

-    logger.info('These layers have been excluded %s' % excluded_sym_names)
+    if logger:
+        logger.info('These layers have been excluded %s' % excluded_sym_names)

     label_name = args.label_name
-    logger.info('label_name = %s' % label_name)
+    if logger:
+        logger.info('label_name = %s' % label_name)

     data_shape = tuple([int(i) for i in image_shape.split(',')])
-    logger.info('Input data shape = %s' % str(data_shape))
-
-    logger.info('rgb_mean = %s' % rgb_mean)
+    if logger:
+        logger.info('Input data shape = %s' % str(data_shape))
+        logger.info('rgb_mean = %s' % rgb_mean)
+        logger.info('rgb_std = %s' % rgb_std)
     rgb_mean = [float(i) for i in rgb_mean.split(',')]
     mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
-    logger.info('rgb_std = %s' % rgb_std)
     rgb_std = [float(i) for i in rgb_std.split(',')]
     std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}
     combine_mean_std = {}
     combine_mean_std.update(mean_args)
     combine_mean_std.update(std_args)
     if calib_mode == 'none':
-        logger.info('Quantizing FP32 model %s' % args.model)
+        if logger:
+            logger.info('Quantizing FP32 model %s' % args.model)
         qsym, qarg_params, aux_params = quantize_model_mkldnn(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                               ctx=ctx, excluded_sym_names=excluded_sym_names,
                                                               calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
                                                               logger=logger)
         sym_name = '%s-symbol.json' % (prefix + '-quantized')
     else:
-        logger.info('Creating ImageRecordIter for reading calibration dataset')
+        if logger:
+            logger.info('Creating ImageRecordIter for reading calibration dataset')
         data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset,
                                      label_width=1,
                                      preprocess_threads=data_nthreads,
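With the flag in place, a quiet run of the example script would look roughly like this (the model name is illustrative; only --quiet is taken from this diff):

    python imagenet_gen_qsym_mkldnn.py --model=resnet50_v1 --quiet
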
python/mxnet/contrib/quantization.py: 54 additions, 36 deletions
@@ -188,8 +188,8 @@ def collect(self, name, arr):
             return
         handle = ctypes.cast(arr, NDArrayHandle)
         arr = NDArray(handle, writable=False).copyto(cpu()).asnumpy()
-        if self.logger is not None:
-            self.logger.info("Collecting layer %s histogram of shape %s" % (name, arr.shape))
+        if self.logger:
+            self.logger.debug("Collecting layer %s histogram of shape %s" % (name, arr.shape))
         min_range = np.min(arr)
         max_range = np.max(arr)
         th = max(abs(min_range), abs(max_range))
@@ -224,9 +224,9 @@ def collect(self, name, arr):
                                        max(cur_min_max[1], max_range))
         else:
             self.min_max_dict[name] = (min_range, max_range)
-        if self.logger is not None:
-            self.logger.info("Collecting layer %s min_range=%f, max_range=%f"
-                             % (name, min_range, max_range))
+        if self.logger:
+            self.logger.debug("Collecting layer %s min_range=%f, max_range=%f"
+                              % (name, min_range, max_range))

 def _calibrate_quantized_sym(qsym, th_dict):
     """Given a dictionary containing the thresholds for quantizing the layers,
@@ -358,18 +358,19 @@ def _get_optimal_thresholds(hist_dict, quantized_dtype, num_quantized_bins=255,
         else:
             th_dict[name] = (-th, th)
         del hist_dict[name]  # release the memory
-        if logger is not None:
-            logger.info('layer=%s, min_val=%f, max_val=%f, th=%f, divergence=%f'
-                        % (name, min_val, max_val, th, divergence))
+        if logger:
+            logger.debug('layer=%s, min_val=%f, max_val=%f, th=%f, divergence=%f'
+                         % (name, min_val, max_val, th, divergence))
     return th_dict


-def _load_sym(sym, logger=logging):
+def _load_sym(sym, logger=None):
     """Given a str as a path the symbol .json file or a symbol, returns a Symbol object."""
     if isinstance(sym, str):  # sym is a symbol file path
         cur_path = os.path.dirname(os.path.realpath(__file__))
         symbol_file_path = os.path.join(cur_path, sym)
-        logger.info('Loading symbol from file %s' % symbol_file_path)
+        if logger:
+            logger.info('Loading symbol from file %s' % symbol_file_path)
         return sym_load(symbol_file_path)
     elif isinstance(sym, Symbol):
         return sym
@@ -378,14 +379,15 @@ def _load_sym(sym, logger=logging):
                         ' while received type %s' % str(type(sym)))


-def _load_params(params, logger=logging):
+def _load_params(params, logger=None):
     """Given a str as a path to the .params file or a pair of params,
     returns two dictionaries representing arg_params and aux_params.
     """
     if isinstance(params, str):
         cur_path = os.path.dirname(os.path.realpath(__file__))
         param_file_path = os.path.join(cur_path, params)
-        logger.info('Loading params from file %s' % param_file_path)
+        if logger:
+            logger.info('Loading params from file %s' % param_file_path)
         save_dict = nd_load(param_file_path)
         arg_params = {}
         aux_params = {}
@@ -451,7 +453,7 @@ def quantize_model(sym, arg_params, aux_params,
                    data_names=('data',), label_names=('softmax_label',),
                    ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy',
                    calib_data=None, num_calib_examples=None,
-                   quantized_dtype='int8', quantize_mode='smart', logger=logging):
+                   quantized_dtype='int8', quantize_mode='smart', logger=None):
     """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
     The backend quantized operators are only enabled for Linux systems. Please do not run
     inference using the quantized models on Windows for now.
@@ -530,7 +532,9 @@ def quantize_model(sym, arg_params, aux_params,
                          ' the names of the operators that will not be quantized,'
                          ' while received type %s' % str(type(excluded_op_names)))

-    logger.info('Quantizing symbol')
+    if logger:
+        os.environ['MXNET_QUANTIZATION_VERBOSE'] = '1'
+        logger.info('Quantizing symbol')
     if quantized_dtype not in ('int8', 'uint8', 'auto'):
         raise ValueError('unknown quantized_dtype %s received,'
                          ' expected `int8`, `uint8` or `auto`' % quantized_dtype)
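
Since the user-level APIs now default to logger=None, they are silent unless a logger is passed explicitly; passing one also sets MXNET_QUANTIZATION_VERBOSE=1 so the backend prints its own messages. A usage sketch, assuming an FP32 checkpoint named 'resnet-50' exists on disk:

    import logging
    import mxnet as mx
    from mxnet.contrib.quantization import quantize_model

    sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50', 0)

    logging.basicConfig()
    logger = logging.getLogger('quantization')
    logger.setLevel(logging.INFO)

    # Verbose run: INFO messages, plus MXNET_QUANTIZATION_VERBOSE=1 in the backend.
    qsym, qarg_params, qaux_params = quantize_model(
        sym, arg_params, aux_params, ctx=mx.cpu(),
        calib_mode='none', quantized_dtype='int8', logger=logger)

    # Silent run: logger defaults to None, so nothing is printed.
    qsym, qarg_params, qaux_params = quantize_model(
        sym, arg_params, aux_params, ctx=mx.cpu(),
        calib_mode='none', quantized_dtype='int8')
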
@@ -561,21 +565,24 @@ def quantize_model(sym, arg_params, aux_params,
                                                      include_layer=calib_layer,
                                                      max_num_examples=num_calib_examples,
                                                      logger=logger)
-        logger.info('Collected layer outputs from FP32 model using %d examples' % num_examples)
-        logger.info('Calculating optimal thresholds for quantization')
+        if logger:
+            logger.info('Collected layer outputs from FP32 model using %d examples' % num_examples)
+            logger.info('Calculating optimal thresholds for quantization')
         th_dict = _get_optimal_thresholds(hist_dict, quantized_dtype, logger=logger)
     elif calib_mode == 'naive':
         th_dict, num_examples = _collect_layer_output_min_max(
             mod, calib_data, quantized_dtype, include_layer=calib_layer, max_num_examples=num_calib_examples,
             logger=logger)
-        logger.info('Collected layer output min/max values from FP32 model using %d examples'
-                    % num_examples)
+        if logger:
+            logger.info('Collected layer output min/max values from FP32 model using %d examples'
+                        % num_examples)
     else:
         raise ValueError('unknown calibration mode %s received,'
                          ' expected `none`, `naive`, or `entropy`' % calib_mode)
     qsym = _calibrate_quantized_sym(qsym, th_dict)

-    logger.info('Quantizing parameters')
+    if logger:
+        logger.info('Quantizing parameters')
     qarg_params = _quantize_params(qsym, arg_params, th_dict)

     return qsym, qarg_params, aux_params
@@ -584,7 +591,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params,
                           data_names=('data',), label_names=('softmax_label',),
                           ctx=cpu(), excluded_sym_names=None, excluded_op_names=None,
                           calib_mode='entropy', calib_data=None, num_calib_examples=None,
-                          quantized_dtype='int8', quantize_mode='smart', logger=logging):
+                          quantized_dtype='int8', quantize_mode='smart', logger=None):
     """User-level API for generating a fusion + quantized model from a FP32 model
     w/ or w/o calibration with Intel MKL-DNN.
     The backend quantized operators are only enabled for Linux systems. Please do not run
@@ -621,7 +628,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params,

 def quantize_graph(sym, arg_params, aux_params, ctx=cpu(),
                    excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy',
-                   quantized_dtype='int8', quantize_mode='full', logger=logging):
+                   quantized_dtype='int8', quantize_mode='full', logger=None):
     """User-level API for generating a quantized model from a FP32 model w/o calibration
     and a collector for naive or entropy calibration.
     The backend quantized operators are only enabled for Linux systems. Please do not run
@@ -676,7 +683,9 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(),
                          ' while received type %s' % str(type(excluded_sym_names)))
     if not isinstance(ctx, Context):
         raise ValueError('currently only supports single ctx, while received %s' % str(ctx))
-    logger.info('Quantizing graph')
+    if logger:
+        os.environ['MXNET_QUANTIZATION_VERBOSE'] = '1'
+        logger.info('Quantizing graph')
     if quantized_dtype not in ('int8', 'uint8', 'auto'):
         raise ValueError('unknown quantized_dtype %s received,'
                          ' expected `int8`, `uint8` or `auto`' % quantized_dtype)
@@ -693,20 +702,24 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(),
         if calib_mode == 'entropy':
             collector = _LayerHistogramCollector(
                 include_layer=calib_layer, logger=logger)
-            logger.info(
-                'Create a layer output collector for entropy calibration.')
+            if logger:
+                logger.info(
+                    'Create a layer output collector for entropy calibration.')
         elif calib_mode == 'naive':
             collector = _LayerOutputMinMaxCollector(quantized_dtype=quantized_dtype,
                                                     include_layer=calib_layer, logger=logger)
-            logger.info(
-                'Create a layer output minmax collector for naive calibration')
+            if logger:
+                logger.info(
+                    'Create a layer output minmax collector for naive calibration')
         else:
             raise ValueError('unknown calibration mode %s received,'
                              ' expected `none`, `naive`, or `entropy`' % calib_mode)
-        logger.info('Collector created, please use set_monitor_callback'
-                    ' to collect calibration information.')
+        if logger:
+            logger.info('Collector created, please use set_monitor_callback'
+                        ' to collect calibration information.')

-    logger.info('Quantizing parameters')
+    if logger:
+        logger.info('Quantizing parameters')
     qarg_params = _quantize_params(qsym, arg_params, th_dict)

     return qsym, qarg_params, aux_params, collector
@@ -751,7 +764,8 @@ def calib_graph(qsym, arg_params, aux_params, collector,
     th_dict = {}
     if calib_mode is not None and calib_mode != 'none':
         if calib_mode == 'entropy':
-            logger.info('Calculating optimal thresholds for quantization')
+            if logger:
+                logger.info('Calculating optimal thresholds for quantization')
             th_dict = _get_optimal_thresholds(
                 collector.hist_dict, quantized_dtype, logger=logger)
         elif calib_mode == 'naive':
@@ -763,15 +777,16 @@ def calib_graph(qsym, arg_params, aux_params, collector,
         else:
             raise ValueError('please set calibration mode to naive or entropy.')

-    logger.info('Quantizing parameters')
+    if logger:
+        logger.info('Quantizing parameters')
     qarg_params = _quantize_params(qsym, arg_params, th_dict)

     return qsym, qarg_params, aux_params

 def quantize_net(network, quantized_dtype='auto', quantize_mode='full',
                  exclude_layers=None, exclude_layers_match=None, exclude_operators=None,
                  calib_data=None, data_shapes=None, calib_mode='none',
-                 num_calib_examples=None, ctx=cpu(), logger=logging):
+                 num_calib_examples=None, ctx=cpu(), logger=None):
     """User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration.
     The backend quantized operators are only enabled for Linux systems. Please do not run
     inference using the quantized models on Windows for now.
@@ -825,7 +840,8 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full',
     -------
     """

-    logger.info('Export HybridBlock')
+    if logger:
+        logger.info('Export HybridBlock')
     network.hybridize()
     import mxnet as mx
     if calib_data is not None:
@@ -881,7 +897,8 @@ def __exit__(self, exc_type, exc_value, traceback):
             for layers in list(symnet.get_internals()):
                 if layers.name.find(name_match) != -1:
                     exclude_layers.append(layers.name)
-    logger.info('These layers have been excluded %s' % exclude_layers)
+    if logger:
+        logger.info('These layers have been excluded %s' % exclude_layers)

     if ctx == mx.cpu():
         symnet = symnet.get_backend_symbol('MKLDNN_QUANTIZE')
@@ -906,8 +923,9 @@ def __exit__(self, exc_type, exc_value, traceback):
         mod.set_params(args, auxs, allow_missing=False, force_init=True)
         num_examples = _collect_layer_statistics(mod, calib_data, collector,
                                                  num_calib_examples, logger)
-        logger.info('Collected layer output values from FP32 model using %d examples'
-                    % num_examples)
+        if logger:
+            logger.info('Collected layer output values from FP32 model using %d examples'
+                        % num_examples)
         qsym, qarg_params, aux_params = calib_graph(
             qsym=qsym, arg_params=args, aux_params=auxs, collector=collector,
             calib_mode=calib_mode, quantized_dtype=quantized_dtype, logger=logger)
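
The Gluon entry point gets the same default: quantize_net now takes logger=None and runs silently unless one is supplied. A sketch of a quiet Gluon call, assuming the model-zoo network is available and that data_shapes accepts mx.io.DataDesc entries (both are placeholders, not shown in this diff):

    import mxnet as mx
    from mxnet.gluon.model_zoo.vision import resnet18_v1
    from mxnet.contrib.quantization import quantize_net

    net = resnet18_v1(pretrained=True)

    # No logger argument, so the call produces no log output.
    qnet = quantize_net(net, quantized_dtype='auto', calib_mode='none',
                        data_shapes=[mx.io.DataDesc('data', (1, 3, 224, 224))],
                        ctx=mx.cpu())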
