This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add quantization support for GluonCV #15754

Merged · 21 commits · Aug 7, 2019
212 changes: 211 additions & 1 deletion python/mxnet/contrib/quantization.py
@@ -26,6 +26,7 @@
import ctypes
import logging
import os
import shutil
import numpy as np
from ..base import _LIB, check_call, py_str
from ..base import c_array, c_str, mx_uint, c_str_array
@@ -34,8 +35,9 @@
from ..symbol import load as sym_load
from .. import ndarray
from ..ndarray import load as nd_load
from ..ndarray import save as nd_save
from ..ndarray import NDArray
from ..io import DataIter
from ..io import DataIter, DataDesc, DataBatch
from ..context import cpu, Context
from ..module import Module

@@ -420,6 +422,44 @@ def _load_params(params, logger=logging):
raise ValueError('Unsupported params provided. Must be either a path to the param file or'
' a pair of dictionaries representing arg_params and aux_params')

# pylint: disable=super-init-not-called
class _DataIterWrapper(DataIter):
"""DataIter wrapper for general iterator, e.g., gluon dataloader"""
def __init__(self, calib_data):
self._data = calib_data
try:
calib_iter = iter(calib_data)
except TypeError as e:
raise TypeError('calib_data is not a valid iterator. {}'.format(str(e)))
data_example = next(calib_iter)
if isinstance(data_example, (list, tuple)):
data_example = list(data_example)
else:
data_example = [data_example]
# assume data_example holds the input array(s), possibly followed by one label
num_data = len(data_example)
assert num_data > 0
self.provide_data = [DataDesc(name='data', shape=data_example[0].shape)]
self.provide_data += [DataDesc(name='data{}'.format(i), shape=x.shape) for i, x in enumerate(data_example[1:])]
self.batch_size = data_example[0].shape[0]
self.reset()

def reset(self):
self._iter = iter(self._data)

def next(self):
return DataBatch(data=next(self._iter))
# pylint: enable=super-init-not-called

def _as_data_iter(calib_data):
"""Convert normal iterator to mx.io.DataIter while parsing the data_shapes"""
if isinstance(calib_data, DataIter):
# already validated DataIter, just return
return calib_data, calib_data.provide_data

calib_data = _DataIterWrapper(calib_data)
return calib_data, calib_data.provide_data
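# Usage sketch (illustrative names): wrapping a gluon DataLoader yields a DataIter
# whose provide_data is inferred from its first batch, e.g.
#   loader = mx.gluon.data.DataLoader(dataset, batch_size=32)
#   calib_iter, dshapes = _as_data_iter(loader)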

def quantize_model(sym, arg_params, aux_params,
data_names=('data',), label_names=('softmax_label',),
ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
@@ -780,3 +820,173 @@ def calib_graph(qsym, arg_params, aux_params, collector,
qarg_params = _quantize_params(qsym, arg_params, th_dict)

return qsym, qarg_params, aux_params

def quantize_net(network, quantized_dtype='auto', exclude_layers=None, exclude_layers_match=None, calib_data=None,
data_shapes=None, calib_mode='none', num_calib_examples=None, ctx=cpu(), logger=logging):
"""User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration.
The backend quantized operators are only enabled for Linux systems. Please do not run
inference using the quantized models on Windows for now.
The quantization implementation adopts TensorFlow's approach:
https://www.tensorflow.org/performance/quantization.
The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT:
http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
and adapts the method to MXNet.

Parameters
----------
network : Gluon HybridBlock
Defines the structure of a neural network for FP32 data types.
quantized_dtype : str
The quantized destination type for input data. Currently supports 'int8',
'uint8' and 'auto'. 'auto' means the output type is selected automatically
according to the calibration result. Default value is 'auto'.
exclude_layers : list of strings
A list of strings representing the names of the symbols that users want to exclude
from being quantized.
exclude_layers_match : list of strings
A list of strings for wildcard matching the names of the symbols that users want
to exclude from being quantized.
calib_data : mx.io.DataIter or gluon.DataLoader
An iterable data loading object.
data_shapes : list
List of DataDesc, required if calib_data is not provided
calib_mode : str
If calib_mode='none', no calibration will be used and the thresholds for
requantization after the corresponding layers will be calculated at runtime by
calling min and max operators. The quantized models generated in this
mode are normally 10-20% slower than those with calibrations during inference.
If calib_mode='naive', the min and max values of the layer outputs from a calibration
dataset will be directly taken as the thresholds for quantization.
If calib_mode='entropy', the thresholds for quantization will be
derived such that the KL divergence between the distributions of FP32 layer outputs and
quantized layer outputs is minimized based upon the calibration dataset.
num_calib_examples : int or None
The maximum number of examples that users would like to use for calibration. If not
provided, the whole calibration dataset will be used.
ctx : Context
Defines the device that users want to run forward propagation on the calibration
dataset for collecting layer output statistics. Currently, only a single context is supported.
logger : Object
A logging object for printing information during the process of quantization.

Returns
-------
network : Gluon SymbolBlock
Defines the structure of a neural network for INT8 data types.
"""

logger.info('Export HybridBlock')
network.hybridize()
import mxnet as mx
if calib_data is not None:
if isinstance(calib_data, DataIter):
dshapes = calib_data.provide_data
else:
calib_data, dshapes = _as_data_iter(calib_data)
if not data_shapes:
data_shapes = dshapes
if not data_shapes:
raise ValueError('data_shapes required')
data_nd = []
for shape in data_shapes:
data_nd.append(mx.nd.zeros(shape.shape))
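# Probe how many inputs the network's forward pass actually accepts: a gluon
# DataLoader batch may carry a trailing label, so drop trailing dummy arrays
# (and the matching DataDescs) until a forward pass succeeds.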
while True:
try:
network(*data_nd)
except TypeError:
del data_nd[-1]
del calib_data.provide_data[-1]
continue
else:
break

import tempfile
try:
from tempfile import TemporaryDirectory
except ImportError:
# minimal fallback implementation of TemporaryDirectory (tempfile.TemporaryDirectory is Python 3.2+)
class TemporaryDirectory(object):
def __init__(self, suffix='', prefix='', dir=''):
self._dirname = tempfile.mkdtemp(suffix, prefix, dir)

def __enter__(self):
return self._dirname

def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self._dirname)
# TODO(xinyu-intel): tmp solution to save and reload for mxnet.mod.Module.
# will enhance `export` function to return `sym, args, auxs` directly.
with TemporaryDirectory() as tmpdirname:
prefix = os.path.join(tmpdirname, 'tmp')
network.export(prefix, epoch=0)
Review comment (Member): I feel like it's more convenient to allow export to return (sym, arg_params, aux_params) so we can avoid using a temporary file.

Review comment (Member): It's not mandatory or urgent now, but gluoncv also has the same implementation.

Review comment (Contributor, Author): Agree, will enhance them later to make it easier to switch between symbolic and block :)

symnet, args, auxs = mx.model.load_checkpoint(prefix, 0)

if exclude_layers is None:
exclude_layers = []
if exclude_layers_match is None:
exclude_layers_match = []
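# e.g. (illustrative) exclude_layers_match=['flatten'] excludes every internal
# symbol whose name contains the substring 'flatten'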
for name_match in exclude_layers_match:
for layers in list(symnet.get_internals()):
if layers.name.find(name_match) != -1:
exclude_layers.append(layers.name)
logger.info('These layers have been excluded: %s' % exclude_layers)

if ctx == mx.cpu():
symnet = symnet.get_backend_symbol('MKLDNN_QUANTIZE')

qsym, qarg_params, aux_params, collector = quantize_graph(
sym=symnet, arg_params=args, aux_params=auxs, excluded_sym_names=exclude_layers,
calib_mode=calib_mode, calib_layer=None, quantized_dtype=quantized_dtype, logger=logger)

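# Calibration flow: quantize_graph returns a collector that registers callbacks
# to record layer outputs; _collect_layer_statistics below drives forward passes
# over calib_data, and calib_graph turns the recorded statistics (min/max for
# 'naive', KL-minimizing histograms for 'entropy') into quantization thresholds.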
if calib_mode is not None and calib_mode != 'none':
if not isinstance(ctx, Context):
raise ValueError(
'currently only supports single ctx, while received %s' % str(ctx))
if calib_data is None:
raise ValueError(
'calib_data must be provided when calib_mode=%s' % calib_mode)
if calib_mode in ['naive', 'entropy']:
data_names = [pair[0] for pair in calib_data.provide_data]
mod = Module(symbol=symnet, context=ctx,
data_names=data_names, label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes)
mod.set_params(args, auxs, allow_missing=False, force_init=True)
num_examples = _collect_layer_statistics(mod, calib_data, collector,
num_calib_examples, logger)
logger.info('Collected layer output values from FP32 model using %d examples'
% num_examples)
qsym, qarg_params, aux_params = calib_graph(
qsym=qsym, arg_params=args, aux_params=auxs, collector=collector,
calib_mode=calib_mode, quantized_dtype=quantized_dtype, logger=logger)
else:
raise ValueError(
'please set calibration mode to naive or entropy.')
elif calib_mode == 'none':
data_names = [pair[0] for pair in data_shapes]

if ctx == mx.cpu():
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')

from ..gluon import SymbolBlock
data_sym = []
for name in data_names:
data_sym.append(mx.sym.var(name))
net = SymbolBlock(qsym, data_sym)
# TODO(xinyu-intel): tmp solution to save param_dict and reload for SymbolBlock
# will enhance SymbolBlock to load args, auxs directly.
with TemporaryDirectory() as tmpdirname:
prefix = os.path.join(tmpdirname, 'tmp')
param_name = '%s-%04d.params' % (prefix + 'net-quantized', 0)
save_dict = {('arg:%s' % k): v.as_in_context(cpu())
for k, v in qarg_params.items()}
save_dict.update({('aux:%s' % k): v.as_in_context(cpu())
for k, v in aux_params.items()})
nd_save(param_name, save_dict)
net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved')
net.collect_params().reset_ctx(ctx)
return net
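
A minimal usage sketch of the quantize_net API above, modeled on the new test
below (the pretrained model, random data, and calibration settings are
illustrative, not prescribed):

import mxnet as mx
from mxnet.gluon.model_zoo import vision

net = vision.resnet18_v1(pretrained=True)
data = mx.random.uniform(shape=(32, 3, 224, 224))
label = mx.random.uniform(shape=(32, 1))
dataset = mx.gluon.data.dataset.ArrayDataset(data, label)
calib_data = mx.gluon.data.DataLoader(dataset, batch_size=1)

qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                     calib_data=calib_data, calib_mode='naive',
                                     num_calib_examples=5, ctx=mx.cpu())
qnet.hybridize(static_alloc=True, static_shape=True)
out = qnet(data)
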
5 changes: 2 additions & 3 deletions tests/python/mkl/test_quantization_mkldnn.py
@@ -27,6 +27,5 @@
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-
-del os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST']
-del os.environ['MXNET_SUBGRAPH_BACKEND']
+    del os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST']
+    del os.environ['MXNET_SUBGRAPH_BACKEND']
48 changes: 48 additions & 0 deletions tests/python/quantization/test_quantization.py
@@ -21,6 +21,7 @@
import os
import mxnet as mx
import numpy as np
from mxnet.gluon.model_zoo import vision
from mxnet.test_utils import assert_almost_equal, assert_exception, rand_ndarray, rand_shape_nd, same, DummyIter
from common import with_seed
from mxnet.module import Module
@@ -898,6 +899,53 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N
for qdtype in ['int8', 'uint8']:
check_quantize_model(qdtype)

@with_seed()
def test_quantize_gluon_with_forward():
def check_quantize_net(qdtype):
if is_test_for_native_cpu():
print('skipped testing test_quantize_gluon_with_forward for native cpu since it is not supported yet')
return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing test_quantize_gluon_with_forward for gpu uint8 since it is not supported yet')
return

data_shape = (32, 3, 224, 224)
data_shapes = [mx.io.DataDesc(name='data', shape=data_shape)]
label_shape = (32, 1)
batch_size = 1
resnet18_v1 = vision.resnet18_v1(pretrained=True)
resnet18_v1.collect_params().reset_ctx(mx.current_context())
excluded_names_match = []
if mx.current_context() == mx.gpu():
excluded_names_match += ['activation', 'relu', 'conv0']
num_calib_examples = 5

random_data = mx.random.uniform(shape=data_shape)
random_label = mx.random.uniform(shape=label_shape)
dataset = mx.gluon.data.dataset.ArrayDataset(random_data, random_label)
calib_data = mx.gluon.data.DataLoader(dataset, batch_size=batch_size)

quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype,
exclude_layers=None,
exclude_layers_match=excluded_names_match,
calib_mode='none',
data_shapes=data_shapes,
ctx=mx.current_context())
quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True)
quantized_resnet18_v1(random_data)

quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype,
exclude_layers=None,
exclude_layers_match=excluded_names_match,
calib_data=calib_data,
calib_mode='naive',
num_calib_examples=num_calib_examples,
ctx=mx.current_context())
quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True)
quantized_resnet18_v1(random_data)

for qdtype in ['int8', 'uint8']:
check_quantize_net(qdtype)

@with_seed()
def test_quantize_sym_with_calib():