From 0eddcc229f303a6631b85941adec09ffdd639cf6 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Tue, 21 May 2019 20:08:25 +0800 Subject: [PATCH 01/17] enhance quantization api --- .../quantization/imagenet_gen_qsym_mkldnn.py | 46 ++++-- python/mxnet/contrib/quantization.py | 156 ++++++++++++++++++ 2 files changed, 191 insertions(+), 11 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 482127ba355c..88def06c2268 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -279,12 +279,17 @@ def save_params(fname, arg_params, aux_params, logger=None): combine_mean_std = {} combine_mean_std.update(mean_args) combine_mean_std.update(std_args) + logger.info('Quantizing FP32 model %s' % args.model) + qsym, qarg_params, aux_params, collector = quantize_graph(sym=sym, arg_params=arg_params, aux_params=aux_params, + excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, calib_layer=calib_layer, + quantized_dtype=args.quantized_dtype, logger=logger) if calib_mode == 'none': - logger.info('Quantizing FP32 model %s' % args.model) - qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, - ctx=ctx, excluded_sym_names=excluded_sym_names, - calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, - logger=logger) + # logger.info('Quantizing FP32 model %s' % args.model) + # qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + # ctx=ctx, excluded_sym_names=excluded_sym_names, + # calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, + # logger=logger) sym_name = '%s-symbol.json' % (prefix + '-quantized') else: logger.info('Creating ImageRecordIter for reading calibration dataset') @@ -301,12 +306,31 @@ def save_params(fname, arg_params, aux_params, logger=None): seed=args.shuffle_seed, **combine_mean_std) - qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, - ctx=ctx, excluded_sym_names=excluded_sym_names, - calib_mode=calib_mode, calib_data=data, - num_calib_examples=num_calib_batches * batch_size, - calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - label_names=(label_name,), logger=logger) + # qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + # ctx=ctx, excluded_sym_names=excluded_sym_names, + # calib_mode=calib_mode, calib_data=data, + # num_calib_examples=num_calib_batches * batch_size, + # calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + # label_names=(label_name,), logger=logger) + mod = mx.mod.Module(symbol=sym, label_names=('softmax_label',), context=ctx) + mod.bind(for_training=False, data_shapes=data.provide_data, label_shapes=data.provide_label) + mod.set_params(arg_params, aux_params) + mod._exec_group.execs[0].set_monitor_callback(collector.collect, monitor_all=True) + num_batches = 0 + num_examples = 0 + max_num_examples = num_calib_batches * batch_size + for batch in data: + mod.forward(data_batch=batch, is_train=False) + num_batches += 1 + num_examples += batch_size + if num_examples >= max_num_examples: + break + if logger is not None: + logger.info("Collected statistics from %d batches with batch_size=%d" + % (num_batches, batch_size)) + qsym, qarg_params, aux_params = calib_graph(qsym=qsym, arg_params=arg_params, aux_params=aux_params, + collector=collector, calib_mode=calib_mode, + quantized_dtype=args.quantized_dtype, logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches elif calib_mode == 'naive': diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index b94b5a8da32a..ae3dcae229d9 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -543,3 +543,159 @@ def quantize_model(sym, arg_params, aux_params, qarg_params = _quantize_params(qsym, arg_params, th_dict) return qsym, qarg_params, aux_params + +def quantize_graph(sym, arg_params, aux_params, + excluded_sym_names=None, calib_mode='entropy', + calib_layer=None, quantized_dtype='int8', logger=logging): + """User-level API for generating a quantized model from a FP32 model w/o calibration + and a collector for naive or entropy calibration. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + The quantization implementation adopts the TensorFlow's approach: + https://www.tensorflow.org/performance/quantization. + The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + and adapts the method to MXNet. + + Parameters + ---------- + sym : str or Symbol + Defines the structure of a neural network for FP32 data types. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + excluded_sym_names : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + calib_layer : function + Given a layer's output name in string, return True or False for deciding whether to + calibrate this layer. If yes, the statistics of the layer's output will be collected; + otherwise, no information of the layer's output will be collected. If not provided, + all the layers' outputs that need requantization will be collected. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8' + , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + tuple + A tuple of quantized symbol, quantized arg_params, aux_params and collector. + ------- + """ + if excluded_sym_names is None: + excluded_sym_names = [] + if not isinstance(excluded_sym_names, list): + raise ValueError('excluded_sym_names must be a list of strings representing' + ' the names of the symbols that will not be quantized,' + ' while received type %s' % str(type(excluded_sym_names))) + + logger.info('Quantizing graph') + if quantized_dtype not in ('int8', 'uint8', 'auto'): + raise ValueError('unknown quantized_dtype %s received,' + ' expected `int8`, `uint8` or `auto`' % quantized_dtype) + qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, + offline_params=list(arg_params.keys()), + quantized_dtype=quantized_dtype) + + th_dict = {} + collector = None + if calib_mode is not None and calib_mode != 'none': + if calib_mode == 'entropy': + collector = _LayerOutputCollector(include_layer=calib_layer, logger=logger) + logger.info('Create a layer output collector for entropy calibration.') + elif calib_mode == 'naive': + collector = _LayerOutputMinMaxCollector(include_layer=calib_layer, logger=logger) + logger.info('Create a layer output minmax collector for naive calibration') + else: + raise ValueError('unknown calibration mode %s received,' + ' expected `none`, `naive`, or `entropy`' % calib_mode) + logger.info('Collector created, please use set_monitor_callback' + ' to collect calibration information.') + + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, th_dict) + + return qsym, qarg_params, aux_params, collector + +def calib_graph(qsym, arg_params, aux_params, collector, + calib_mode='entropy', quantized_dtype='int8', logger=logging): + """User-level API for calibrating a quantized model using a filled collector. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + The quantization implementation adopts the TensorFlow's approach: + https://www.tensorflow.org/performance/quantization. + The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + and adapts the method to MXNet. + + Parameters + ---------- + qsym : str or Symbol + Defines the structure of a neural network for INT8 data types. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + collector : function + layer collector for naive or entropy calibration. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + calib_layer : function + Given a layer's output name in string, return True or False for deciding whether to + calibrate this layer. If yes, the statistics of the layer's output will be collected; + otherwise, no information of the layer's output will be collected. If not provided, + all the layers' outputs that need requantization will be collected. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8' + , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + tuple + A tuple of calibrated symbol, quantized arg_params, aux_params. + ------- + """ + th_dict = {} + if calib_mode is not None and calib_mode != 'none': + if calib_mode == 'entropy': + logger.info('Calculating optimal thresholds for quantization') + th_dict = _get_optimal_thresholds(collector.nd_dict, quantized_dtype, logger=logger) + elif calib_mode == 'naive': + th_dict = collector.min_max_dict + else: + raise ValueError('unknown calibration mode %s received,' + ' expected `none`, `naive`, or `entropy`' % calib_mode) + logger.info('Calibrating quantized symbol') + qsym = _calibrate_quantized_sym(qsym, th_dict) + else: + raise ValueError('please set calibration mode to naive or entropy.') + + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, th_dict) + + return qsym, qarg_params, aux_params From a8a41494c9c1192e2a2685332ebdeafcbaf5cf63 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Mon, 3 Jun 2019 16:23:31 +0800 Subject: [PATCH 02/17] integrate gluoncv solution --- python/mxnet/contrib/quantization.py | 94 ++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index ae3dcae229d9..4d127a6d3c1c 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -32,8 +32,11 @@ from ..base import NDArrayHandle, SymbolHandle from ..symbol import Symbol from ..symbol import load as sym_load +from ..symbol import load_json as sym_load_json from .. import ndarray from ..ndarray import load as nd_load +from ..ndarray import save as nd_save +from ..ndarray import zeros as zeros from ..ndarray import NDArray from ..io import DataIter from ..context import cpu, Context @@ -699,3 +702,94 @@ def calib_graph(qsym, arg_params, aux_params, collector, qarg_params = _quantize_params(qsym, arg_params, th_dict) return qsym, qarg_params, aux_params + +def static_net_forward(sym, arg_params, aux_params, collector, + calib_data, data_shapes, ctx=cpu(), logger=logging): + """symbolic forward a static model with gluon dataset. + """ + num_batches = len(calib_data) + num_inputs = len(data_shapes) + data_names = () + data_shapes_ = [] + for i in range(num_inputs): + data_names = data_names + ('data' + str(i),) + data_shapes_.append((data_names[i], data_shapes[i])) + mod = Module(symbol=sym, context=ctx, + data_names=data_names, label_names=None) + mod.bind(for_training=False, data_shapes=data_shapes_) + mod.set_params(arg_params, aux_params, + allow_missing=False, force_init=True) + mod._exec_group.execs[0].set_monitor_callback( + collector.collect, monitor_all=True) + for batch in calib_data: + mod.forward(data_batch=batch, is_train=False) + if logger is not None: + logger.info("Collected statistics from %d batches" + % num_batches) + return collector, data_names + +def save_params(fname, arg_params, aux_params, logger=logging): + if logger is not None: + logger.info('Saving tmp params into file at %s' % fname) + save_dict = {('arg:%s' % k): v.as_in_context(cpu()) + for k, v in arg_params.items()} + save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) + for k, v in aux_params.items()}) + nd_save(fname, save_dict) + +def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_data=None, + data_shapes=None, calib_mode='none', ctx=cpu(), logger=logging): + """quantize a gluon net + """ + + logger.info('Export symbolblock') + network.hybridize() + import mxnet as mx + data_sym = [] + for i in range(len(data_shapes)): + data_sym.append(mx.sym.var('data'+str(i))) + symnet = sym_load_json(network(*data_sym).tojson()) + params = network.collect_params() + args = {} + auxs = {} + for param in params.values(): + v = param._reduce() + k = param.name + if 'running' in k: + auxs[k] = v + else: + args[k] = v + + logger.info('Exclude all fc layers') + if exclude_layers is None: + exclude_layers = [] + for layers in list(symnet.get_internals()): + if layers.name.find('dense') != -1: + exclude_layers.append(layers.name) + symnet = symnet.get_backend_symbol('MKLDNN') + qsym, qarg_params, aux_params, collector = quantize_graph(sym=symnet, arg_params=args, aux_params=auxs, + excluded_sym_names=exclude_layers, calib_mode=calib_mode, + calib_layer=None, quantized_dtype=quantized_dtype, logger=logger) + + if calib_mode is not None and calib_mode != 'none': + if calib_mode in ['naive', 'entropy']: + collector, data_names = static_net_forward(sym=symnet, arg_params=args, aux_params=auxs, + collector=collector, calib_data=calib_data, + data_shapes=data_shapes, ctx=cpu(), logger=logger) + qsym, qarg_params, aux_params = calib_graph(qsym=qsym, arg_params=args, aux_params=auxs, + collector=collector, calib_mode=calib_mode, + quantized_dtype=quantized_dtype, logger=logger) + logger.info('Calibrating quantized symbol') + else: + raise ValueError('please set calibration mode to naive or entropy.') + qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') + + from ..gluon import SymbolBlock + net = SymbolBlock(qsym, data_sym) + import tempfile + with tempfile.TemporaryDirectory() as tmpdirname: + prefix = os.path.join(tmpdirname, 'tmp') + param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0) + save_params(param_name, qarg_params, aux_params, logger) + net.collect_params().load(param_name) + return net From 322898ffaa78ad83a213102d9db7d03069620a29 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Wed, 5 Jun 2019 22:11:24 +0800 Subject: [PATCH 03/17] support gluon ssd --- python/mxnet/contrib/quantization.py | 42 +++++++++++++--------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 4d127a6d3c1c..e2e3bd657662 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -711,9 +711,13 @@ def static_net_forward(sym, arg_params, aux_params, collector, num_inputs = len(data_shapes) data_names = () data_shapes_ = [] - for i in range(num_inputs): - data_names = data_names + ('data' + str(i),) - data_shapes_.append((data_names[i], data_shapes[i])) + if num_inputs == 1: + data_names = ('data',) + data_shapes_ = [(data_names[0], data_shapes[0])] + else: + for i in range(num_inputs): + data_names = data_names + ('data' + str(i),) + data_shapes_.append((data_names[i], data_shapes[i])) mod = Module(symbol=sym, context=ctx, data_names=data_names, label_names=None) mod.bind(for_training=False, data_shapes=data_shapes_) @@ -729,8 +733,6 @@ def static_net_forward(sym, arg_params, aux_params, collector, return collector, data_names def save_params(fname, arg_params, aux_params, logger=logging): - if logger is not None: - logger.info('Saving tmp params into file at %s' % fname) save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()} save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) @@ -745,26 +747,20 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_dat logger.info('Export symbolblock') network.hybridize() import mxnet as mx - data_sym = [] - for i in range(len(data_shapes)): - data_sym.append(mx.sym.var('data'+str(i))) - symnet = sym_load_json(network(*data_sym).tojson()) - params = network.collect_params() - args = {} - auxs = {} - for param in params.values(): - v = param._reduce() - k = param.name - if 'running' in k: - auxs[k] = v - else: - args[k] = v - + data_nd = [] + for shape in data_shapes: + data_nd.append(mx.nd.zeros(shape)) + network(*data_nd) + import tempfile + with tempfile.TemporaryDirectory() as tmpdirname: + prefix = os.path.join(tmpdirname, 'tmp') + network.export(prefix, epoch=0) + symnet, args, auxs = mx.model.load_checkpoint(prefix, 0) logger.info('Exclude all fc layers') if exclude_layers is None: exclude_layers = [] for layers in list(symnet.get_internals()): - if layers.name.find('dense') != -1: + if layers.name.find('dense') != -1 or layers.name.find('flatten') != -1: exclude_layers.append(layers.name) symnet = symnet.get_backend_symbol('MKLDNN') qsym, qarg_params, aux_params, collector = quantize_graph(sym=symnet, arg_params=args, aux_params=auxs, @@ -785,8 +781,10 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_dat qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') from ..gluon import SymbolBlock + data_sym = [] + for name in data_names: + data_sym.append(mx.sym.var(name)) net = SymbolBlock(qsym, data_sym) - import tempfile with tempfile.TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0) From ef3cdde6dec637e5fa573915344499e2acf3225d Mon Sep 17 00:00:00 2001 From: chenxiny Date: Thu, 6 Jun 2019 13:06:33 +0800 Subject: [PATCH 04/17] enhance api --- python/mxnet/contrib/quantization.py | 143 ++++++++++++++++++++++----- 1 file changed, 116 insertions(+), 27 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index e2e3bd657662..dc5c0585dfec 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -705,7 +705,32 @@ def calib_graph(qsym, arg_params, aux_params, collector, def static_net_forward(sym, arg_params, aux_params, collector, calib_data, data_shapes, ctx=cpu(), logger=logging): - """symbolic forward a static model with gluon dataset. + """Symbolic forward a symbol converted from gluon HybridBlock to use monitor API. + + Parameters + ---------- + sym : str or Symbol + Defines the structure of a neural network for FP32 data types. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + collector : function + Layer collector for naive or entropy calibration. + calib_data : list + A list containing several batches of input data for calibration. + data_shapes : list + A list containing shapes of input data for symbolic bind. + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + collector : function + Layer collector containing statistical information of each layer. + data_names : tuple + A tuple containing data names for gluon SymbolBlock import. + ------- """ num_batches = len(calib_data) num_inputs = len(data_shapes) @@ -714,10 +739,12 @@ def static_net_forward(sym, arg_params, aux_params, collector, if num_inputs == 1: data_names = ('data',) data_shapes_ = [(data_names[0], data_shapes[0])] - else: + elif num_inputs > 1: for i in range(num_inputs): data_names = data_names + ('data' + str(i),) data_shapes_.append((data_names[i], data_shapes[i])) + else: + raise ValueError('symbol must have at least one inputs.') mod = Module(symbol=sym, context=ctx, data_names=data_names, label_names=None) mod.bind(for_training=False, data_shapes=data_shapes_) @@ -732,53 +759,111 @@ def static_net_forward(sym, arg_params, aux_params, collector, % num_batches) return collector, data_names -def save_params(fname, arg_params, aux_params, logger=logging): - save_dict = {('arg:%s' % k): v.as_in_context(cpu()) - for k, v in arg_params.items()} - save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) - for k, v in aux_params.items()}) - nd_save(fname, save_dict) - def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_data=None, data_shapes=None, calib_mode='none', ctx=cpu(), logger=logging): - """quantize a gluon net + """User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + The quantization implementation adopts the TensorFlow's approach: + https://www.tensorflow.org/performance/quantization. + The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + and adapts the method to MXNet. + + Parameters + ---------- + network : Gluon HybridBlock + Defines the structure of a neural network for FP32 data types. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8' + , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + exclude_layers : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. + calib_data : list + A list containing several batches of input data for calibration. + data_shapes : list + A list containing shapes of input data for symbolic bind. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + calib_layer : function + Given a layer's output name in string, return True or False for deciding whether to + calibrate this layer. If yes, the statistics of the layer's output will be collected; + otherwise, no information of the layer's output will be collected. If not provided, + all the layers' outputs that need requantization will be collected. + ctx : Context + Defines the device that users want to run forward propagation on the calibration + dataset for collecting layer output statistics. Currently, only supports single context. + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + network : Gluon SymbolBlock + Defines the structure of a neural network for INT8 data types. + ------- """ - logger.info('Export symbolblock') + logger.info('Export HybridBlock') network.hybridize() import mxnet as mx data_nd = [] for shape in data_shapes: data_nd.append(mx.nd.zeros(shape)) network(*data_nd) + import tempfile with tempfile.TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') network.export(prefix, epoch=0) symnet, args, auxs = mx.model.load_checkpoint(prefix, 0) - logger.info('Exclude all fc layers') + + logger.info('Exclude all FullyConnect and Flatten layers.') if exclude_layers is None: exclude_layers = [] for layers in list(symnet.get_internals()): - if layers.name.find('dense') != -1 or layers.name.find('flatten') != -1: + if layers.name.find('dense') != -1\ + or layers.name.find('fc') != -1\ + or layers.name.find('flatten') != -1: exclude_layers.append(layers.name) - symnet = symnet.get_backend_symbol('MKLDNN') - qsym, qarg_params, aux_params, collector = quantize_graph(sym=symnet, arg_params=args, aux_params=auxs, - excluded_sym_names=exclude_layers, calib_mode=calib_mode, - calib_layer=None, quantized_dtype=quantized_dtype, logger=logger) + + if ctx == mx.cpu(): + symnet = symnet.get_backend_symbol('MKLDNN_QUANTIZE') + + qsym, qarg_params, aux_params, collector = quantize_graph( + sym=symnet, arg_params=args, aux_params=auxs, excluded_sym_names=exclude_layers, + calib_mode=calib_mode, calib_layer=None, quantized_dtype=quantized_dtype, logger=logger) if calib_mode is not None and calib_mode != 'none': + if not isinstance(ctx, Context): + raise ValueError( + 'currently only supports single ctx, while received %s' % str(ctx)) + if calib_data is None: + raise ValueError( + 'calib_data must be provided when calib_mode=%s' % calib_mode) if calib_mode in ['naive', 'entropy']: - collector, data_names = static_net_forward(sym=symnet, arg_params=args, aux_params=auxs, - collector=collector, calib_data=calib_data, - data_shapes=data_shapes, ctx=cpu(), logger=logger) - qsym, qarg_params, aux_params = calib_graph(qsym=qsym, arg_params=args, aux_params=auxs, - collector=collector, calib_mode=calib_mode, - quantized_dtype=quantized_dtype, logger=logger) - logger.info('Calibrating quantized symbol') + collector, data_names = static_net_forward( + sym=symnet, arg_params=args, aux_params=auxs, collector=collector, + calib_data=calib_data, data_shapes=data_shapes, ctx=cpu(), + logger=logger) + qsym, qarg_params, aux_params = calib_graph( + qsym=qsym, arg_params=args, aux_params=auxs, collector=collector, + calib_mode=calib_mode, quantized_dtype=quantized_dtype, logger=logger) else: - raise ValueError('please set calibration mode to naive or entropy.') - qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') + raise ValueError( + 'please set calibration mode to naive or entropy.') + + if ctx == mx.cpu(): + qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') from ..gluon import SymbolBlock data_sym = [] @@ -788,6 +873,10 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_dat with tempfile.TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0) - save_params(param_name, qarg_params, aux_params, logger) + save_dict = {('arg:%s' % k): v.as_in_context(cpu()) + for k, v in qarg_params.items()} + save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) + for k, v in aux_params.items()}) + nd_save(param_name, save_dict) net.collect_params().load(param_name) return net From 4bc77443d9768a490f5dcc8ccf68424a53815623 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Thu, 6 Jun 2019 15:40:12 +0800 Subject: [PATCH 05/17] [TODO]split to another PR --- src/imperative/imperative_utils.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 5c9706834b2d..185b4af9172c 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -595,7 +595,10 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, *contain_unknown = false; } nnvm::Graph& g = *p_g; - if (g.attrs.count("shape")) { + if (use_inputs) { + if (g.attrs.count("shape_inputs") && g.GetAttr("shape_inputs") == shapes) + return true; + } else if (g.attrs.count("shape")) { const auto& prev_shapes = g.GetAttr("shape"); if (prev_shapes.size() == shapes.size()) { bool match = true; From 1f8c22dfb7c5aad804abbe169e8842c700394324 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Thu, 6 Jun 2019 15:50:35 +0800 Subject: [PATCH 06/17] enhance example script --- .../quantization/imagenet_gen_qsym_mkldnn.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 88def06c2268..a4a7aaebaa17 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -173,6 +173,8 @@ def save_params(fname, arg_params, aux_params, logger=None): 'mobilenetv2_1.0', 'inceptionv3']: logger.info('model %s is converted from GluonCV' % args.model) + rgb_mean = '123.68,116.779,103.939' + rgb_std = '58.393, 57.12, 57.375' args.use_gluon_model = True if args.use_gluon_model == True: prefix = convert_from_gluon(model_name=args.model, image_shape=args.image_shape, classes=1000, logger=logger) @@ -225,32 +227,22 @@ def save_params(fname, arg_params, aux_params, logger=None): if exclude_first_conv: excluded_sym_names += ['conv_1'] elif args.model in ['resnet18_v1', 'resnet50_v1', 'resnet101_v1']: - rgb_mean = '123.68,116.779,103.939' - rgb_std = '58.393, 57.12, 57.375' if exclude_first_conv: excluded_sym_names += ['resnetv10_conv0_fwd'] elif args.model == 'squeezenet1.0': - rgb_mean = '123.68,116.779,103.939' - rgb_std = '58.393, 57.12, 57.375' excluded_sym_names += ['squeezenet0_flatten0_flatten0'] if exclude_first_conv: excluded_sym_names += ['squeezenet0_conv0_fwd'] elif args.model == 'mobilenet1.0': - rgb_mean = '123.68,116.779,103.939' - rgb_std = '58.393, 57.12, 57.375' excluded_sym_names += ['mobilenet0_flatten0_flatten0', 'mobilenet0_pool0_fwd'] if exclude_first_conv: excluded_sym_names += ['mobilenet0_conv0_fwd'] elif args.model == 'mobilenetv2_1.0': - rgb_mean = '123.68,116.779,103.939' - rgb_std = '58.393, 57.12, 57.375' excluded_sym_names += ['mobilenetv20_output_flatten0_flatten0'] if exclude_first_conv: excluded_sym_names += ['mobilenetv20_conv0_fwd'] elif args.model == 'inceptionv3': - rgb_mean = '123.68,116.779,103.939' - rgb_std = '58.393, 57.12, 57.375' if exclude_first_conv: excluded_sym_names += ['inception30_conv0_fwd'] elif args.model == 'custom': @@ -285,11 +277,6 @@ def save_params(fname, arg_params, aux_params, logger=None): calib_mode=calib_mode, calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, logger=logger) if calib_mode == 'none': - # logger.info('Quantizing FP32 model %s' % args.model) - # qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, - # ctx=ctx, excluded_sym_names=excluded_sym_names, - # calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, - # logger=logger) sym_name = '%s-symbol.json' % (prefix + '-quantized') else: logger.info('Creating ImageRecordIter for reading calibration dataset') @@ -306,12 +293,6 @@ def save_params(fname, arg_params, aux_params, logger=None): seed=args.shuffle_seed, **combine_mean_std) - # qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, - # ctx=ctx, excluded_sym_names=excluded_sym_names, - # calib_mode=calib_mode, calib_data=data, - # num_calib_examples=num_calib_batches * batch_size, - # calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - # label_names=(label_name,), logger=logger) mod = mx.mod.Module(symbol=sym, label_names=('softmax_label',), context=ctx) mod.bind(for_training=False, data_shapes=data.provide_data, label_shapes=data.provide_label) mod.set_params(arg_params, aux_params) From df3406db2fd1d19256cb467fd3762ed8734bd616 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Mon, 10 Jun 2019 13:56:52 +0800 Subject: [PATCH 07/17] add wildcard match for exclude layers --- python/mxnet/contrib/quantization.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index dc5c0585dfec..3a290909b4e2 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -759,7 +759,7 @@ def static_net_forward(sym, arg_params, aux_params, collector, % num_batches) return collector, data_names -def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_data=None, +def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_layers_match=None, calib_data=None, data_shapes=None, calib_mode='none', ctx=cpu(), logger=logging): """User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run @@ -780,6 +780,8 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_dat Default value is 'int8'. exclude_layers : list of strings A list of strings representing the names of the symbols that users want to excluding + exclude_layers_match : list of strings + A list of strings wildcard matching the names of the symbols that users want to excluding from being quantized. calib_data : list A list containing several batches of input data for calibration. @@ -827,14 +829,15 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, calib_dat network.export(prefix, epoch=0) symnet, args, auxs = mx.model.load_checkpoint(prefix, 0) - logger.info('Exclude all FullyConnect and Flatten layers.') if exclude_layers is None: exclude_layers = [] - for layers in list(symnet.get_internals()): - if layers.name.find('dense') != -1\ - or layers.name.find('fc') != -1\ - or layers.name.find('flatten') != -1: - exclude_layers.append(layers.name) + if exclude_layers_match is None: + exclude_layers_match = [] + for name_match in exclude_layers_match: + for layers in list(symnet.get_internals()): + if layers.name.find(name_match) != -1: + exclude_layers.append(layers.name) + logger.info('These layers have been excluded %s' % exclude_layers) if ctx == mx.cpu(): symnet = symnet.get_backend_symbol('MKLDNN_QUANTIZE') From 1279bbd82dd6ace6e9aeb7411d98c6b4861e3456 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 24 Jun 2019 14:29:21 +0000 Subject: [PATCH 08/17] support int8 dtype parameter --- python/mxnet/contrib/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 3a290909b4e2..0b6f61d9c2be 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -881,5 +881,5 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) nd_save(param_name, save_dict) - net.collect_params().load(param_name) + net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved') return net From c3f4968ffe3b6dc405004b0889610dd2cc11b46c Mon Sep 17 00:00:00 2001 From: chenxiny Date: Thu, 4 Jul 2019 11:01:26 +0800 Subject: [PATCH 09/17] enable dataiter api --- python/mxnet/contrib/quantization.py | 128 +++++++++++++-------------- 1 file changed, 62 insertions(+), 66 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 0b6f61d9c2be..35ca40465db2 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -38,7 +38,7 @@ from ..ndarray import save as nd_save from ..ndarray import zeros as zeros from ..ndarray import NDArray -from ..io import DataIter +from ..io import DataIter, DataDesc, DataBatch from ..context import cpu, Context from ..module import Module @@ -423,6 +423,42 @@ def _load_params(params, logger=logging): raise ValueError('Unsupported params provided. Must be either a path to the param file or' ' a pair of dictionaries representing arg_params and aux_params') +class _DataIterWrapper(DataIter): + """DataIter wrapper for general iterator, e.g., gluon dataloader""" + def __init__(self, calib_data): + self._data = calib_data + try: + calib_iter = iter(calib_data) + except TypeError as e: + raise TypeError('calib_data is not a valid iterator. {}'.format(str(e))) + data_example = next(calib_iter) + if isinstance(data_example, (list, tuple)): + data_example = list(data_example) + else: + data_example = [data_example] + # suppose there must be one label in data_example + num_data = len(data_example) - 1 + assert num_data > 0 + self.provide_data = [DataDesc(name='data', shape=(data_example[0].shape))] + self.provide_data += [DataDesc(name='data{}'.format(i), shape=x.shape) for i, x in enumerate(data_example[1:num_data])] + self.batch_size = data_example[0].shape[0] + self.reset() + + def reset(self): + self._iter = iter(self._data) + + def next(self): + return DataBatch(data=next(self._iter)) + +def _as_data_iter(calib_data): + """Convert normal iterator to mx.io.DataIter while parsing the data_shapes""" + if isinstance(calib_data, DataIter): + # already validated DataIter, just return + return calib_data, calib_data.provide_data + + calib_data = _DataIterWrapper(calib_data) + return calib_data, calib_data.provide_data + def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', @@ -703,64 +739,8 @@ def calib_graph(qsym, arg_params, aux_params, collector, return qsym, qarg_params, aux_params -def static_net_forward(sym, arg_params, aux_params, collector, - calib_data, data_shapes, ctx=cpu(), logger=logging): - """Symbolic forward a symbol converted from gluon HybridBlock to use monitor API. - - Parameters - ---------- - sym : str or Symbol - Defines the structure of a neural network for FP32 data types. - arg_params : dict - Dictionary of name to `NDArray`. - aux_params : dict - Dictionary of name to `NDArray`. - collector : function - Layer collector for naive or entropy calibration. - calib_data : list - A list containing several batches of input data for calibration. - data_shapes : list - A list containing shapes of input data for symbolic bind. - logger : Object - A logging object for printing information during the process of quantization. - - Returns - ------- - collector : function - Layer collector containing statistical information of each layer. - data_names : tuple - A tuple containing data names for gluon SymbolBlock import. - ------- - """ - num_batches = len(calib_data) - num_inputs = len(data_shapes) - data_names = () - data_shapes_ = [] - if num_inputs == 1: - data_names = ('data',) - data_shapes_ = [(data_names[0], data_shapes[0])] - elif num_inputs > 1: - for i in range(num_inputs): - data_names = data_names + ('data' + str(i),) - data_shapes_.append((data_names[i], data_shapes[i])) - else: - raise ValueError('symbol must have at least one inputs.') - mod = Module(symbol=sym, context=ctx, - data_names=data_names, label_names=None) - mod.bind(for_training=False, data_shapes=data_shapes_) - mod.set_params(arg_params, aux_params, - allow_missing=False, force_init=True) - mod._exec_group.execs[0].set_monitor_callback( - collector.collect, monitor_all=True) - for batch in calib_data: - mod.forward(data_batch=batch, is_train=False) - if logger is not None: - logger.info("Collected statistics from %d batches" - % num_batches) - return collector, data_names - def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_layers_match=None, calib_data=None, - data_shapes=None, calib_mode='none', ctx=cpu(), logger=logging): + data_shapes=None, calib_mode='none', num_calib_examples=None, ctx=cpu(), logger=logging): """User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -783,10 +763,10 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l exclude_layers_match : list of strings A list of strings wildcard matching the names of the symbols that users want to excluding from being quantized. - calib_data : list - A list containing several batches of input data for calibration. + calib_data : mx.io.DataIter or gluon.DataLoader + A iterable data loading object. data_shapes : list - A list containing shapes of input data for symbolic bind. + List of data_shape, required if calib_data is not provided calib_mode : str If calib_mode='none', no calibration will be used and the thresholds for requantization after the corresponding layers will be calculated at runtime by @@ -802,6 +782,9 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l calibrate this layer. If yes, the statistics of the layer's output will be collected; otherwise, no information of the layer's output will be collected. If not provided, all the layers' outputs that need requantization will be collected. + num_calib_examples : int or None + The maximum number of examples that user would like to use for calibration. If not provided, + the whole calibration dataset will be used. ctx : Context Defines the device that users want to run forward propagation on the calibration dataset for collecting layer output statistics. Currently, only supports single context. @@ -818,9 +801,18 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l logger.info('Export HybridBlock') network.hybridize() import mxnet as mx + if calib_data is not None: + if isinstance(calib_data, DataIter): + dshapes = calib_data.provide_data + else: + calib_data, dshapes = _as_data_iter(calib_data) + if not data_shapes: + data_shapes = dshapes + if not data_shapes: + raise ValueError('data_shapes required') data_nd = [] for shape in data_shapes: - data_nd.append(mx.nd.zeros(shape)) + data_nd.append(mx.nd.zeros(shape.shape)) network(*data_nd) import tempfile @@ -854,10 +846,14 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l raise ValueError( 'calib_data must be provided when calib_mode=%s' % calib_mode) if calib_mode in ['naive', 'entropy']: - collector, data_names = static_net_forward( - sym=symnet, arg_params=args, aux_params=auxs, collector=collector, - calib_data=calib_data, data_shapes=data_shapes, ctx=cpu(), - logger=logger) + data_names = [pair[0] for pair in calib_data.provide_data] + mod = Module(symbol=symnet, context=ctx, + data_names=data_names, label_names=None) + mod.bind(for_training=False, data_shapes=data_shapes) + mod.set_params(args, auxs, allow_missing=False, force_init=True) + num_examples = _collect_layer_statistics(mod, calib_data, collector, num_calib_examples, logger) + logger.info('Collected layer output values from FP32 model using %d examples' + % num_examples) qsym, qarg_params, aux_params = calib_graph( qsym=qsym, arg_params=args, aux_params=auxs, collector=collector, calib_mode=calib_mode, quantized_dtype=quantized_dtype, logger=logger) From b1cb61fc1715f5773bbe20ba812022ccfd559908 Mon Sep 17 00:00:00 2001 From: chenxiny Date: Thu, 4 Jul 2019 14:05:05 +0800 Subject: [PATCH 10/17] use try method --- python/mxnet/contrib/quantization.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 35ca40465db2..3092391b9698 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -437,10 +437,10 @@ def __init__(self, calib_data): else: data_example = [data_example] # suppose there must be one label in data_example - num_data = len(data_example) - 1 + num_data = len(data_example) assert num_data > 0 self.provide_data = [DataDesc(name='data', shape=(data_example[0].shape))] - self.provide_data += [DataDesc(name='data{}'.format(i), shape=x.shape) for i, x in enumerate(data_example[1:num_data])] + self.provide_data += [DataDesc(name='data{}'.format(i), shape=x.shape) for i, x in enumerate(data_example[1:])] self.batch_size = data_example[0].shape[0] self.reset() @@ -813,7 +813,15 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l data_nd = [] for shape in data_shapes: data_nd.append(mx.nd.zeros(shape.shape)) - network(*data_nd) + while True: + try: + network(*data_nd) + except: + del data_nd[-1] + del calib_data.provide_data[-1] + continue + else: + break import tempfile with tempfile.TemporaryDirectory() as tmpdirname: From 0ec3b6861b133b95d1b1efcc260917f0de81e5de Mon Sep 17 00:00:00 2001 From: xinyu Date: Mon, 5 Aug 2019 15:02:13 +0800 Subject: [PATCH 11/17] add unit test for quantize gluon --- python/mxnet/contrib/quantization.py | 8 +++- tests/python/mkl/test_quantization_mkldnn.py | 5 +- .../python/quantization/test_quantization.py | 48 +++++++++++++++++++ 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 0470098bff25..9c4609cd9123 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -847,7 +847,7 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l calib_data : mx.io.DataIter or gluon.DataLoader A iterable data loading object. data_shapes : list - List of data_shape, required if calib_data is not provided + List of DataDesc, required if calib_data is not provided calib_mode : str If calib_mode='none', no calibration will be used and the thresholds for requantization after the corresponding layers will be calculated at runtime by @@ -940,7 +940,8 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l data_names=data_names, label_names=None) mod.bind(for_training=False, data_shapes=data_shapes) mod.set_params(args, auxs, allow_missing=False, force_init=True) - num_examples = _collect_layer_statistics(mod, calib_data, collector, num_calib_examples, logger) + num_examples = _collect_layer_statistics(mod, calib_data, collector, + num_calib_examples, logger) logger.info('Collected layer output values from FP32 model using %d examples' % num_examples) qsym, qarg_params, aux_params = calib_graph( @@ -949,6 +950,8 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l else: raise ValueError( 'please set calibration mode to naive or entropy.') + elif calib_mode is not None and calib_mode == 'none': + data_names = [pair[0] for pair in data_shapes] if ctx == mx.cpu(): qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') @@ -967,4 +970,5 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l for k, v in aux_params.items()}) nd_save(param_name, save_dict) net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved') + net.collect_params().reset_ctx(ctx) return net diff --git a/tests/python/mkl/test_quantization_mkldnn.py b/tests/python/mkl/test_quantization_mkldnn.py index 3c8cee465ec5..8ba2f2b01feb 100644 --- a/tests/python/mkl/test_quantization_mkldnn.py +++ b/tests/python/mkl/test_quantization_mkldnn.py @@ -27,6 +27,5 @@ if __name__ == '__main__': import nose nose.runmodule() - -del os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] -del os.environ['MXNET_SUBGRAPH_BACKEND'] + del os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] + del os.environ['MXNET_SUBGRAPH_BACKEND'] diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 294e10763220..a991417b28c9 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -21,6 +21,7 @@ import os import mxnet as mx import numpy as np +from mxnet.gluon.model_zoo import vision from mxnet.test_utils import assert_almost_equal, assert_exception, rand_ndarray, rand_shape_nd, same, DummyIter from common import with_seed from mxnet.module import Module @@ -898,6 +899,53 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N for qdtype in ['int8', 'uint8']: check_quantize_model(qdtype) +@with_seed() +def test_quantize_gluon_with_forward(): + def check_quantize_net(qdtype): + if is_test_for_native_cpu(): + print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet') + return + elif qdtype == 'uint8' and is_test_for_gpu(): + print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet') + return + + data_shape = (32, 3, 224, 224) + data_shapes = [mx.io.DataDesc(name='data', shape=data_shape)] + label_shape = (32, 1) + batch_size = 1 + resnet18_v1 = vision.resnet18_v1(pretrained=True) + resnet18_v1.collect_params().reset_ctx(mx.current_context()) + excluded_names_match = [] + if mx.current_context() == mx.gpu(): + excluded_names_match += ['activation', 'relu', 'conv0'] + num_calib_examples = 5 + + random_data = mx.random.uniform(shape=data_shape) + random_label = mx.random.uniform(shape=label_shape) + dataset = mx.gluon.data.dataset.ArrayDataset(random_data, random_label) + calib_data = mx.gluon.data.DataLoader(dataset, batch_size=batch_size) + + quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype, + exclude_layers=None, + exclude_layers_match=excluded_names_match, + calib_mode='none', + data_shapes=data_shapes, + ctx=mx.current_context()) + quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) + quantized_resnet18_v1(random_data) + + quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype, + exclude_layers=None, + exclude_layers_match=excluded_names_match, + calib_data=calib_data, + calib_mode='naive', + num_calib_examples=num_calib_examples, + ctx=mx.current_context()) + quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) + quantized_resnet18_v1(random_data) + + for qdtype in ['int8', 'uint8']: + check_quantize_net(qdtype) @with_seed() def test_quantize_sym_with_calib(): From 3e8e5cfd0aa6c670080c881030142fb5bec881a6 Mon Sep 17 00:00:00 2001 From: xinyu Date: Mon, 5 Aug 2019 15:41:50 +0800 Subject: [PATCH 12/17] fix lint --- python/mxnet/contrib/quantization.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 9c4609cd9123..865feba183fe 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -32,11 +32,9 @@ from ..base import NDArrayHandle, SymbolHandle from ..symbol import Symbol from ..symbol import load as sym_load -from ..symbol import load_json as sym_load_json from .. import ndarray from ..ndarray import load as nd_load from ..ndarray import save as nd_save -from ..ndarray import zeros as zeros from ..ndarray import NDArray from ..io import DataIter, DataDesc, DataBatch from ..context import cpu, Context @@ -965,9 +963,9 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l prefix = os.path.join(tmpdirname, 'tmp') param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0) save_dict = {('arg:%s' % k): v.as_in_context(cpu()) - for k, v in qarg_params.items()} + for k, v in qarg_params.items()} save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) - for k, v in aux_params.items()}) + for k, v in aux_params.items()}) nd_save(param_name, save_dict) net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved') net.collect_params().reset_ctx(ctx) From c5198b3857f141327c385df9ad5716953019a0dc Mon Sep 17 00:00:00 2001 From: xinyu Date: Mon, 5 Aug 2019 16:06:47 +0800 Subject: [PATCH 13/17] fix lint 2 --- python/mxnet/contrib/quantization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 865feba183fe..3f1c3ed2f132 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -421,6 +421,7 @@ def _load_params(params, logger=logging): raise ValueError('Unsupported params provided. Must be either a path to the param file or' ' a pair of dictionaries representing arg_params and aux_params') +# pylint: disable=super-init-not-called class _DataIterWrapper(DataIter): """DataIter wrapper for general iterator, e.g., gluon dataloader""" def __init__(self, calib_data): @@ -447,6 +448,7 @@ def reset(self): def next(self): return DataBatch(data=next(self._iter)) +# pylint: enable=super-init-not-called def _as_data_iter(calib_data): """Convert normal iterator to mx.io.DataIter while parsing the data_shapes""" @@ -895,7 +897,7 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l while True: try: network(*data_nd) - except: + except TypeError: del data_nd[-1] del calib_data.provide_data[-1] continue From aac1d05371ddf328ebacc93a2499210e96cacb31 Mon Sep 17 00:00:00 2001 From: xinyu Date: Mon, 5 Aug 2019 22:51:03 +0800 Subject: [PATCH 14/17] fix temporary directory in python2 --- python/mxnet/contrib/quantization.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 3f1c3ed2f132..daebd3cc3b3e 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -905,7 +905,21 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l break import tempfile - with tempfile.TemporaryDirectory() as tmpdirname: + try: + from tempfile import TemporaryDirectory + except: + # really simple implementation of TemporaryDirectory + class TemporaryDirectory(object): + def __init__(self, suffix='', prefix='', dir=''): + self._dirname = tempfile.mkdtemp(suffix, prefix, dir) + + def __enter__(self): + return self._dirname + + def __exit__(self, exc_type, exc_value, traceback): + shutil.rmtree(self._dirname) + + with TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') network.export(prefix, epoch=0) symnet, args, auxs = mx.model.load_checkpoint(prefix, 0) @@ -961,7 +975,7 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l for name in data_names: data_sym.append(mx.sym.var(name)) net = SymbolBlock(qsym, data_sym) - with tempfile.TemporaryDirectory() as tmpdirname: + with TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0) save_dict = {('arg:%s' % k): v.as_in_context(cpu()) From 5593519ecaa9852253f16d343329cc51f99c9120 Mon Sep 17 00:00:00 2001 From: xinyu Date: Mon, 5 Aug 2019 22:52:49 +0800 Subject: [PATCH 15/17] fix lint --- python/mxnet/contrib/quantization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index daebd3cc3b3e..7899a703ef70 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -26,6 +26,7 @@ import ctypes import logging import os +import shutil import numpy as np from ..base import _LIB, check_call, py_str from ..base import c_array, c_str, mx_uint, c_str_array @@ -907,7 +908,7 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l import tempfile try: from tempfile import TemporaryDirectory - except: + except AttributeError: # really simple implementation of TemporaryDirectory class TemporaryDirectory(object): def __init__(self, suffix='', prefix='', dir=''): From 375d2b572bcd579cb6a78c8e10fd24928b288007 Mon Sep 17 00:00:00 2001 From: xinyu Date: Tue, 6 Aug 2019 09:49:47 +0800 Subject: [PATCH 16/17] fix try import and add todo --- python/mxnet/contrib/quantization.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 7899a703ef70..262cc075bc89 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -908,7 +908,7 @@ def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_l import tempfile try: from tempfile import TemporaryDirectory - except AttributeError: + except ImportError: # really simple implementation of TemporaryDirectory class TemporaryDirectory(object): def __init__(self, suffix='', prefix='', dir=''): @@ -919,7 +919,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): shutil.rmtree(self._dirname) - + # TODO(xinyu-intel): tmp solution to save and reload for mxnet.mod.Module. + # will enhance `export` function to return `sym, args, auxs` directly. with TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') network.export(prefix, epoch=0) @@ -976,6 +977,8 @@ def __exit__(self, exc_type, exc_value, traceback): for name in data_names: data_sym.append(mx.sym.var(name)) net = SymbolBlock(qsym, data_sym) + # TODO(xinyu-intel): tmp solution to save param_dict and reload for SymbolBlock + # will enhance SymbolBlock to load args, auxs directly. with TemporaryDirectory() as tmpdirname: prefix = os.path.join(tmpdirname, 'tmp') param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0) From 6117167662b61624817c527341a02a28b6e85286 Mon Sep 17 00:00:00 2001 From: xinyu Date: Wed, 7 Aug 2019 10:09:10 +0800 Subject: [PATCH 17/17] trigger