Implement mkldnn convolution fusion and quantization. (#12530)
* Implement mkldnn convolution fusion.
Implement mkldnn convolution quantization.

* Fix lint

* Fix performance regression caused by mkldnn fallback.

* clean up include

* Fix msbuild on openmp pragma.

* Fix quantization test; allow using original op names as excluded layers for quantization.

* Fix unittest.

* Fix unittest

* fix lint

* Add post quantize fusion

* add test case

* add head license in test case

* Remove GetBoolHash()

* Remove mkldnn fallback change.

* Address Haibin's comments.

* Add TIsMKLDNN for _sg_mkldnn_conv temporarily.

* Address reminisce's comments.

* Handle the case where in-place fails.

* pass unit test.

* Add symbol api get_backend_symbol()

* Retrigger ci

* update the test case

* Check subgraph index.

* Use index as FAvoidQuantizeInput's parameter.

* Add mkldnn_hwigo support, which quantization needs.

* Address KellenSunderland's comments.

* Handle input order change after subgraph pass.

* Fix ci test
ZhennanQin authored and eric-haibin-lin committed Oct 9, 2018
1 parent f0140b3 commit ad027ca
Showing 29 changed files with 2,416 additions and 216 deletions.
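Taken together, the changes introduce a three-step flow for producing an INT8 model on CPU: run the new MKLDNN subgraph pass over the FP32 symbol to fuse convolution patterns, quantize the fused symbol, then run the post-quantize pass to fold the calibration ops into the fused convolutions. The sketch below condenses that flow from the new example script further down; the checkpoint prefix, epoch, and excluded layer names are illustrative placeholders, not values fixed by the commit.

import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Load an FP32 checkpoint (placeholder prefix/epoch).
sym, arg_params, aux_params = mx.model.load_checkpoint('model/imagenet1k-resnet-152', 0)

# 1. Fuse convolution patterns with the new MKLDNN subgraph backend.
sym = sym.get_backend_symbol('MKLDNN')

# 2. Quantize the fused symbol (calibration omitted in this minimal sketch).
qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    ctx=mx.cpu(0), excluded_sym_names=['flatten0', 'fc1'],
    calib_mode='none', quantized_dtype='uint8')

# 3. Fold requantize/dequantize ops into the fused conv with the post-quantize pass.
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')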
4 changes: 2 additions & 2 deletions Makefile
@@ -66,8 +66,8 @@ $(warning "USE_MKL2017 is deprecated. We will switch to USE_MKLDNN.")
endif

ifeq ($(USE_MKLDNN), 1)
-MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/install
-MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/install
+MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install
+MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install
export USE_MKLML = 1
endif

2 changes: 1 addition & 1 deletion example/quantization/imagenet_gen_qsym.py
@@ -92,7 +92,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
                             ' thresholds. This mode is expected to produce the best inference accuracy of all three'
                             ' kinds of quantized models if the calibration dataset is representative enough of the'
                             ' inference dataset.')
-    parser.add_argument('--quantized-dtype', type=str, default='int8',
+    parser.add_argument('--quantized-dtype', type=str, default='int8',
                        choices=['int8', 'uint8'],
                        help='quantization destination data type for input data')
    args = parser.parse_args()
207 changes: 207 additions & 0 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -0,0 +1,207 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import os
import logging
from common import modelzoo
import mxnet as mx
from mxnet.contrib.quantization import *
from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
import ctypes


def download_calib_dataset(dataset_url, calib_dataset, logger=None):
    if logger is not None:
        logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset))
    mx.test_utils.download(dataset_url, calib_dataset)


def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
    return modelzoo.download_model(model_name, model_path)


def save_symbol(fname, sym, logger=None):
    if logger is not None:
        logger.info('Saving symbol into file at %s' % fname)
    sym.save(fname)


def save_params(fname, arg_params, aux_params, logger=None):
    if logger is not None:
        logger.info('Saving params into file at %s' % fname)
    save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from an FP32 model with MKL-DNN support')
    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--label-name', type=str, default='softmax_label')
    parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec',
                        help='path of the calibration dataset')
    parser.add_argument('--image-shape', type=str, default='3,224,224')
    parser.add_argument('--data-nthreads', type=int, default=60,
                        help='number of threads for data decoding')
    parser.add_argument('--num-calib-batches', type=int, default=10,
                        help='number of batches for calibration')
    parser.add_argument('--exclude-first-conv', action='store_true', default=True,
                        help='exclude the first conv layer from quantization, since the input'
                             ' data may contain negative values, which are not supported at the moment')
    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
                        help='shuffle the calibration dataset')
    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
                        help='shuffling chunk seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')
    parser.add_argument('--shuffle-seed', type=int, default=48564309,
                        help='shuffling seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')
    parser.add_argument('--calib-mode', type=str, default='entropy',
                        help='calibration mode used for generating calibration table for the quantized symbol; supports'
                             ' 1. none: no calibration will be used. The thresholds for quantization will be calculated'
                             ' on the fly. This will result in inference speed slowdown and loss of accuracy'
                             ' in general.'
                             ' 2. naive: simply take min and max values of layer outputs as thresholds for'
                             ' quantization. In general, the inference accuracy worsens with more examples used in'
                             ' calibration. It is recommended to use `entropy` mode as it produces more accurate'
                             ' inference results.'
                             ' 3. entropy: calculate KL divergence of the fp32 output and quantized output for optimal'
                             ' thresholds. This mode is expected to produce the best inference accuracy of all three'
                             ' kinds of quantized models if the calibration dataset is representative enough of the'
                             ' inference dataset.')
    parser.add_argument('--quantized-dtype', type=str, default='uint8',
                        choices=['int8', 'uint8'],
                        help='quantization destination data type for input data')
    parser.add_argument('--enable-calib-quantize', type=bool, default=True,
                        help='if enabled, the quantize op will be calibrated offline'
                             ' when a calibration mode is enabled')
    args = parser.parse_args()
    ctx = mx.cpu(0)
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    logger.info('shuffle_dataset=%s' % args.shuffle_dataset)

    calib_mode = args.calib_mode
    logger.info('calibration mode set to %s' % calib_mode)

    # download calibration dataset
    if calib_mode != 'none':
        download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset)

    # download model
    prefix, epoch = download_model(model_name=args.model, logger=logger)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    sym = sym.get_backend_symbol('MKLDNN')

    # get batch size
    batch_size = args.batch_size
    logger.info('batch size = %d for calibration' % batch_size)

    # get number of batches for calibration
    num_calib_batches = args.num_calib_batches
    if calib_mode == 'none':
        logger.info('skip calibration step as calib_mode is none')
    else:
        logger.info('number of batches = %d for calibration' % num_calib_batches)

    # get number of threads for decoding the dataset
    data_nthreads = args.data_nthreads

    # get image shape
    image_shape = args.image_shape

    exclude_first_conv = args.exclude_first_conv
    excluded_sym_names = []
    if args.model == 'imagenet1k-resnet-152':
        rgb_mean = '0,0,0'
        calib_layer = lambda name: name.endswith('_output')
        excluded_sym_names += ['flatten0', 'fc1']
        if exclude_first_conv:
            excluded_sym_names += ['conv0', 'pooling0']
    elif args.model == 'imagenet1k-inception-bn':
        rgb_mean = '123.68,116.779,103.939'
        calib_layer = lambda name: name.endswith('_output')
        excluded_sym_names += ['flatten', 'fc1']
        if exclude_first_conv:
            excluded_sym_names += ['conv_1']
    else:
        raise ValueError('model %s is not supported in this script' % args.model)

    label_name = args.label_name
    logger.info('label_name = %s' % label_name)

    data_shape = tuple([int(i) for i in image_shape.split(',')])
    logger.info('Input data shape = %s' % str(data_shape))

    logger.info('rgb_mean = %s' % rgb_mean)
    rgb_mean = [float(i) for i in rgb_mean.split(',')]
    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}

    if calib_mode == 'none':
        logger.info('Quantizing FP32 model %s' % args.model)
        qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                       ctx=ctx, excluded_sym_names=excluded_sym_names,
                                                       calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
                                                       logger=logger)
        sym_name = '%s-symbol.json' % (prefix + '-quantized')
    else:
        logger.info('Creating ImageRecordIter for reading calibration dataset')
        data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset,
                                     label_width=1,
                                     preprocess_threads=data_nthreads,
                                     batch_size=batch_size,
                                     data_shape=data_shape,
                                     label_name=label_name,
                                     rand_crop=False,
                                     rand_mirror=False,
                                     shuffle=args.shuffle_dataset,
                                     shuffle_chunk_seed=args.shuffle_chunk_seed,
                                     seed=args.shuffle_seed,
                                     **mean_args)

        qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                       ctx=ctx, excluded_sym_names=excluded_sym_names,
                                                       calib_mode=calib_mode, calib_data=data,
                                                       num_calib_examples=num_calib_batches * batch_size,
                                                       calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
                                                       label_names=(label_name,), calib_quantize_op=True,
                                                       logger=logger)
        if calib_mode == 'entropy':
            suffix = '-quantized-%dbatches-entropy' % num_calib_batches
        elif calib_mode == 'naive':
            suffix = '-quantized-%dbatches-naive' % num_calib_batches
        else:
            raise ValueError('unknown calibration mode %s received, only supports `none`, `naive`, and `entropy`'
                             % calib_mode)
        sym_name = '%s-symbol.json' % (prefix + suffix)
    qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
    save_symbol(sym_name, qsym, logger)
    param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
    save_params(param_name, qarg_params, aux_params, logger)
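For reference, here is a minimal sketch of loading and running the artifacts this script saves, assuming the resnet-152 model, epoch 0, and an entropy run over 10 calibration batches — the file names below simply follow the prefix/suffix logic above and should be adjusted to your run:

import mxnet as mx

prefix = 'model/imagenet1k-resnet-152'
sym = mx.sym.load('%s-symbol.json' % (prefix + '-quantized-10batches-entropy'))
save_dict = mx.nd.load('%s-0000.params' % (prefix + '-quantized'))
arg_params = {k[4:]: v for k, v in save_dict.items() if k.startswith('arg:')}
aux_params = {k[4:]: v for k, v in save_dict.items() if k.startswith('aux:')}

batch_size = 32
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=['softmax_label'])
mod.bind(for_training=False,
         data_shapes=[('data', (batch_size, 3, 224, 224))],
         label_shapes=[('softmax_label', (batch_size,))])
mod.set_params(arg_params, aux_params)

# Run one dummy batch through the quantized, MKLDNN-fused graph.
batch = mx.io.DataBatch([mx.nd.random.uniform(shape=(batch_size, 3, 224, 224))], [])
mod.forward(batch, is_train=False)
print(mod.get_outputs()[0].shape)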
2 changes: 1 addition & 1 deletion example/quantization/imagenet_inference.py
@@ -129,7 +129,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
        ctx = mx.cpu(0)
    else:
        raise ValueError('ctx %s is not supported in this script' % args.ctx)

    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)
22 changes: 15 additions & 7 deletions include/mxnet/c_api.h
@@ -1542,18 +1542,17 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
 * \param sym_handle symbol to be converted
 * \param ret_sym_handle quantized symbol result
 * \param num_excluded_symbols number of layers excluded from being quantized in the input symbol
- * \param excluded_symbols array of symbols to be excluded from being quantized
+ * \param excluded_symbols op names to be excluded from being quantized
 * \param num_offline number of parameters that are quantized offline
 * \param offline_params array of c strings representing the names of params quantized offline
 * \param quantized_dtype the quantized destination type for input data.
+ * \param calib_quantize whether to calibrate the quantize op with offline calibration data.
 */
-MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle,
-                               SymbolHandle *ret_sym_handle,
+MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
                               const mx_uint num_excluded_symbols,
-                               const SymbolHandle *excluded_symbols,
-                               const mx_uint num_offline,
-                               const char **offline_params,
-                               const char *quantized_dtype);
+                               const char **excluded_symbols,
+                               const mx_uint num_offline, const char **offline_params,
+                               const char *quantized_dtype, const bool calib_quantize);

/*!
* \brief Set calibration table to node attributes in the sym
@@ -1571,6 +1570,15 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
                                               const float* high_quantiles,
                                               SymbolHandle* ret_sym_handle);

+/*!
+ * \brief Run subgraph pass based on the backend provided
+ * \param sym_handle symbol to be converted
+ * \param backend backend names for subgraph pass
+ * \param ret_sym_handle returned symbol
+ */
+MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
+                                   SymbolHandle *ret_sym_handle);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
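On the Python side, the new Symbol.get_backend_symbol() used throughout the example script above is a thin wrapper over MXGenBackendSubgraph. A rough ctypes sketch of that mapping, assuming the usual mxnet.base helpers rather than quoting the shipped implementation:

import ctypes
from mxnet.base import _LIB, SymbolHandle, c_str, check_call
from mxnet.symbol import Symbol

def gen_backend_subgraph(sym, backend):
    # Run the named subgraph pass (e.g. 'MKLDNN' or 'MKLDNN_POST_QUANTIZE')
    # over `sym` and wrap the returned handle in a new Symbol.
    out = SymbolHandle()
    check_call(_LIB.MXGenBackendSubgraph(sym.handle, c_str(backend), ctypes.byref(out)))
    return Symbol(out)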
11 changes: 11 additions & 0 deletions include/mxnet/ndarray.h
@@ -667,6 +667,12 @@ class NDArray {
}

#if MXNET_USE_MKLDNN == 1
+  /*
+   * Create NDArray from mkldnn memory.
+   * mkldnn_mem The mkldnn memory to be managed.
+   * static_data If true, mkldnn memory won't be freed on destruction.
+   */
+  explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
/*
* Test if the data is stored in one of special MKLDNN format.
*/
@@ -742,6 +748,11 @@ class NDArray {
* It's used by FullyConnected right now.
*/
NDArray MKLDNNDataReshape(const TShape &shape) const;

+  /*!
+   * \brief Fix mkldnn memory descriptor mismatch from NDArray.
+   */
+  void UpdateMKLDNNMemDesc();
#endif

/*!
8 changes: 8 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -300,6 +300,14 @@ using FQuantizedOp = std::function<nnvm::NodePtr (const NodeAttrs& attrs)>;
*/
using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;

+/*!
+ * \brief Register a function to determine if the input of a quantized operator
+ * needs to be quantized. This is usually used for the quantized operators
+ * which can handle fp32 inputs directly.
+ */
+using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
+                                                size_t index)>;

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
2 changes: 1 addition & 1 deletion mkldnn.mk
@@ -47,7 +47,7 @@ $(MKLDNN_LIBFILE):

mkldnn_clean:
	$(RM) -r 3rdparty/mkldnn/build
-	$(RM) -r 3rdparty/mkldnn/install/*
+	$(RM) -r $(MKLDNNROOT)

ifeq ($(USE_MKLDNN), 1)
mkldnn: mkldnn_build