From afef734bfa0cfc81ca0882f91ef26f4f715b2ec4 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 31 Jul 2019 15:10:16 +0800 Subject: [PATCH 01/14] Add mkldnn imlementation for quantized flatten and smart quantize mode Change-Id: Id9d7504890852faf4e84fdcd66585d1fd78beeb2 --- example/quantization/README.md | 23 ++- .../quantization/imagenet_gen_qsym_mkldnn.py | 80 +++++------ example/ssd/quantization.py | 5 - include/mxnet/c_api.h | 4 +- include/mxnet/op_attr_types.h | 16 +++ python/mxnet/contrib/quantization.py | 16 ++- src/c_api/c_api_symbolic.cc | 5 +- src/operator/nn/mkldnn/mkldnn_flatten-inl.h | 45 ++++++ src/operator/nn/mkldnn/mkldnn_flatten.cc | 10 +- .../mkldnn/mkldnn_quantized_flatten.cc | 61 ++++++++ .../quantization/quantize_graph_pass.cc | 135 ++++++++++++++---- src/operator/quantization/quantized_conv.cc | 3 + .../quantization/quantized_fully_connected.cc | 3 + src/operator/subgraph/mkldnn/mkldnn_conv.cc | 3 + src/operator/subgraph/mkldnn/mkldnn_fc.cc | 3 + tests/python/mkl/test_subgraph.py | 6 +- .../python/quantization/test_quantization.py | 18 ++- 17 files changed, 328 insertions(+), 108 deletions(-) create mode 100644 src/operator/nn/mkldnn/mkldnn_flatten-inl.h create mode 100644 src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc diff --git a/example/quantization/README.md b/example/quantization/README.md index 1ae58fbb3a69..2215a13a97d5 100644 --- a/example/quantization/README.md +++ b/example/quantization/README.md @@ -86,21 +86,20 @@ Use the following command to install [Gluon-CV](https://gluon-cv.mxnet.io/): pip install gluoncv ``` -Below are some quantization demos. These models have been tested on Linux systems. +The following models have been tested on Linux systems. Accuracy is collected on Intel XEON Cascade Lake CPU. For CPU with Skylake Lake or eariler architecture, the accuracy may not be the same. | Model | Source | Dataset | FP32 Accuracy (top-1/top-5)| INT8 Accuracy (top-1/top-5)| |:---|:---|---|:---:|:---:| -| [ResNet18-V1](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |70.15%/89.38%|69.92%/89.26%| -| [ResNet50-V1](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) | 76.34%/93.13% | 75.91%/92.95% | -| [ResNet50-V1b](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) | 76.82%/93.38% | 76.39%/93.24% | -| [ResNet101-V1](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) | 77.33%/93.59% | 77.05%/93.43% | -|[Squeezenet 1.0](#4)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|56.98%/79.20%|52.98%/77.21%| -|[MobileNet 1.0](#5)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.23%/90.64%|72.03%/90.42%| -|[MobileNetV2 1.0](#6)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|70.27%/89.62%|69.70%/89.26%| -|[Inception V3](#7)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|77.76%/93.83% |77.87%/93.78% | -|[ResNet152-V2](#8)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.65%/93.07%|76.36%/92.89%| -|[Inception-BN](#9)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/inception-bn/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.28%/90.63%|72.20%/90.56%| -| [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | VOC2007/2012 | 0.8366 mAP | 0.8364 mAP | +| [ResNet18-V1](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |70.15%/89.38%|69.92%/89.30%| +| [ResNet50-V1](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) | 76.34%/93.13% | 76.06%/92.99% | +| [ResNet101-V1](#3) | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) | 77.33%/93.59% | 77.07%/93.47% | +|[Squeezenet 1.0](#4)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|56.98%/79.20%|56.79%/79.47%| +|[MobileNet 1.0](#5)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.23%/90.64%|72.06%/90.53%| +|[MobileNetV2 1.0](#6)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|70.27%/89.62%|69.82%/89.35%| +|[Inception V3](#7)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|77.76%/93.83% |78.05%/93.91% | +|[ResNet152-V2](#8)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.65%/93.07%|76.19%/92.88%| +|[Inception-BN](#9)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/inception-bn/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.28%/90.63%|72.02%/90.53%| +| [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | VOC2007/2012 | 0.8366 mAP | 0.8357 mAP | | [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | COCO2014 | 0.2552 mAP | 0.253 mAP |

ResNetV1

diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 302a04449885..9b69c8088cc0 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -199,49 +199,43 @@ def save_params(fname, arg_params, aux_params, logger=None): logger.info('quantized dtype is set to uint8, will exclude first conv.') exclude_first_conv = True excluded_sym_names = [] - if not args.no_pretrained: - if args.model == 'imagenet1k-resnet-152': - rgb_mean = '0,0,0' - rgb_std = '1,1,1' - excluded_sym_names += ['flatten0'] - if exclude_first_conv: - excluded_sym_names += ['conv0'] - elif args.model == 'imagenet1k-inception-bn': - rgb_mean = '123.68,116.779,103.939' - rgb_std = '1,1,1' - excluded_sym_names += ['flatten'] - if exclude_first_conv: - excluded_sym_names += ['conv_1'] - elif args.model.find('resnet') != -1 and args.model.find('v1') != -1: - if exclude_first_conv: - excluded_sym_names += ['resnetv10_conv0_fwd'] - elif args.model.find('resnet') != -1 and args.model.find('v2') != -1: - excluded_sym_names += ['resnetv20_flatten0_flatten0'] - if exclude_first_conv: - excluded_sym_names += ['resnetv20_conv0_fwd'] - elif args.model.find('vgg') != -1: - if exclude_first_conv: - excluded_sym_names += ['vgg0_conv0_fwd'] - elif args.model.find('squeezenet1') != -1: - excluded_sym_names += ['squeezenet0_flatten0_flatten0'] - if exclude_first_conv: - excluded_sym_names += ['squeezenet0_conv0_fwd'] - elif args.model.find('mobilenet') != -1 and args.model.find('v2') == -1: - excluded_sym_names += ['mobilenet0_flatten0_flatten0', - 'mobilenet0_pool0_fwd'] - if exclude_first_conv: - excluded_sym_names += ['mobilenet0_conv0_fwd'] - elif args.model.find('mobilenet') != -1 and args.model.find('v2') != -1: - excluded_sym_names += ['mobilenetv20_output_flatten0_flatten0'] - if exclude_first_conv: - excluded_sym_names += ['mobilenetv20_conv0_fwd'] - elif args.model == 'inceptionv3': - if exclude_first_conv: - excluded_sym_names += ['inception30_conv0_fwd'] - else: - raise ValueError('Currently, model %s is not supported in this script' % args.model) - else: - logger.info('Please set proper RGB configs for model %s' % args.model) + if args.model == 'imagenet1k-resnet-152': + rgb_mean = '0,0,0' + rgb_std = '1,1,1' + if exclude_first_conv: + excluded_sym_names += ['conv0'] + elif args.model == 'imagenet1k-inception-bn': + rgb_mean = '123.68,116.779,103.939' + rgb_std = '1,1,1' + if exclude_first_conv: + excluded_sym_names += ['conv_1'] + elif args.model in ['resnet18_v1', 'resnet50_v1', 'resnet101_v1']: + rgb_mean = '123.68,116.779,103.939' + rgb_std = '58.393, 57.12, 57.375' + if exclude_first_conv: + excluded_sym_names += ['resnetv10_conv0_fwd'] + elif args.model == 'squeezenet1.0': + rgb_mean = '123.68,116.779,103.939' + rgb_std = '58.393, 57.12, 57.375' + if exclude_first_conv: + excluded_sym_names += ['squeezenet0_conv0_fwd'] + elif args.model == 'mobilenet1.0': + rgb_mean = '123.68,116.779,103.939' + rgb_std = '58.393, 57.12, 57.375' + excluded_sym_names += ['mobilenet0_pool0_fwd'] + if exclude_first_conv: + excluded_sym_names += ['mobilenet0_conv0_fwd'] + elif args.model == 'mobilenetv2_1.0': + rgb_mean = '123.68,116.779,103.939' + rgb_std = '58.393, 57.12, 57.375' + if exclude_first_conv: + excluded_sym_names += ['mobilenetv20_conv0_fwd'] + elif args.model == 'inceptionv3': + rgb_mean = '123.68,116.779,103.939' + rgb_std = '58.393, 57.12, 57.375' + if exclude_first_conv: + excluded_sym_names += ['inception30_conv0_fwd'] + elif args.model == 'custom': # add rgb mean/std of your model. rgb_mean = '0,0,0' rgb_std = '0,0,0' diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py index d50935499240..b92aadcf47ac 100644 --- a/example/ssd/quantization.py +++ b/example/ssd/quantization.py @@ -123,11 +123,6 @@ def calib_layer(name): return not (name.endswith('_data') or exclude_first_conv = args.exclude_first_conv excluded_sym_names = [] rgb_mean = '123,117,104' - for i in range(1,19): - excluded_sym_names += ['flatten'+str(i)] - excluded_sym_names += ['multibox_loc_pred', - 'concat0', - 'concat1'] if exclude_first_conv: excluded_sym_names += ['conv1_1'] diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d4f756f5333c..dfe0a9f98b25 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1907,6 +1907,7 @@ MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym, * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could + * \param quantize_mode quantize mode to be used in quantize pass */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const int* dev_type, @@ -1915,7 +1916,8 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_ha const mx_uint num_excluded_op_names, const char **excluded_op_names, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype, const bool calib_quantize); + const char *quantized_dtype, const bool calib_quantize, + const char *quantize_mode); /*! * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 889b5028a460..9f4f5d3eadd2 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -131,6 +131,16 @@ enum class DispatchMode { kVariable, }; +/*! \brief the quantization type of the operator */ +enum class QuantizeType { + // This operator doesn't support quantization + kNone = 0, + // This operator can get huge benefit from quantization, thus must be quantized + kMust, + // This operator support quantization, but will be decided depending on the connection + kSupport, +}; + /*! * \brief Operator state. This is a pointer type, its content is mutable * even if OpStatePtr is const. @@ -297,6 +307,12 @@ using FInferStorageType = std::function* in_attrs, std::vector* out_attrs)>; +/*! + * \brief Register a quantized node creation function based on the attrs of the node + * \note Register under "FQuantizedOp" for non-quantized operators + */ +using FQuantizable = std::function; + /*! * \brief Register a quantized node creation function based on the attrs of the node * \note Register under "FQuantizedOp" for non-quantized operators diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index c34e934bfa02..f324bb3f82be 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -83,8 +83,8 @@ def _quantize_params(qsym, params, th_dict): quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params -def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, - offline_params=None, quantized_dtype='int8'): +def _quantize_symbol(sym, excluded_symbols=None, excluded_operators=None, + offline_params=None, quantized_dtype='int8', quantize_mode='smart'): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -106,6 +106,9 @@ def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, avoided. quantized_dtype: str The quantized destination type for input data. + quantize_mode: str + The mode that quantization pass to apply. + """ num_excluded_symbols = 0 if excluded_symbols is not None: @@ -139,7 +142,8 @@ def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, mx_uint(num_offline), c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), - ctypes.c_bool(True))) + ctypes.c_bool(True), + c_str(quantize_mode))) return Symbol(out) @@ -480,7 +484,7 @@ def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, - quantized_dtype='int8', logger=logging): + quantized_dtype='int8', quantize_mode='smart', logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -537,6 +541,8 @@ def quantize_model(sym, arg_params, aux_params, The quantized destination type for input data. Currently support 'int8' , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. Default value is 'int8'. + quantize_mode : str + The mode that quantization pass to apply. Support 'full' and 'smart'. 'full' means quantize all operator if possible. 'smart' means quantization pass will smartly choice which operator should be quantized. logger : Object A logging object for printing information during the process of quantization. @@ -567,7 +573,7 @@ def quantize_model(sym, arg_params, aux_params, qsym = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, excluded_operators=excluded_op_names, offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype) + quantized_dtype=quantized_dtype, quantize_mode=quantize_mode) th_dict = {} if calib_mode is not None and calib_mode != 'none': diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index be0cad648015..c8046e3c216a 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -888,7 +888,8 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, - const bool calib_quantize) { + const bool calib_quantize, + const char *quantize_mode) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -907,11 +908,13 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, offline.emplace(offline_params[i]); } std::string quantized_type(quantized_dtype); + std::string quantized_mode(quantize_mode); g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["excluded_ops"] = std::make_shared(std::move(excluded_op)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); g.attrs["target_ctx"] = std::make_shared(target_dev); + g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); g = ApplyPass(std::move(g), "QuantizeGraph"); s->outputs = g.outputs; *ret_sym_handle = s; diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h new file mode 100644 index 000000000000..376db076c28e --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_flatten.cc + * \brief Implement flatten operator by using mkldnn reorder primitive + * \author Wuxun Zhang + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "mkldnn_reshape-inl.h" + +namespace mxnet { +namespace op { + +class MKLDNNFlattenFwd : public MKLDNNReshapeFwd { + public: + explicit MKLDNNFlattenFwd(const OpReqType &req, const NDArray &input, const NDArray &output) + : MKLDNNReshapeFwd(req, input, output) {} +}; + +void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const NDArray &input, + const OpReqType &req, const NDArray &output); + +} // namespace op +} // namespace mxnet + +#endif diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc index fdc02f960009..4090eb026cfc 100644 --- a/src/operator/nn/mkldnn/mkldnn_flatten.cc +++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc @@ -25,19 +25,11 @@ #if MXNET_USE_MKLDNN == 1 -#include "mkldnn_reshape-inl.h" +#include "mkldnn_flatten-inl.h" namespace mxnet { namespace op { -class MKLDNNFlattenFwd : public MKLDNNReshapeFwd { - public: - explicit MKLDNNFlattenFwd(const OpReqType &req, - const NDArray &input, - const NDArray &output) - : MKLDNNReshapeFwd(req, input, output) {} -}; - static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req, const NDArray &input, const NDArray &output) { diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc new file mode 100644 index 000000000000..31da936915e6 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file quantized_flatten.cc + * \brief + */ + +#if MXNET_USE_MKLDNN == 1 +#include "../../nn/mkldnn/mkldnn_flatten-inl.h" +#include "../quantization_utils.h" + +namespace mxnet { +namespace op { + +inline static bool FlattenStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, + DispatchMode* dispatch_mode, std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 3U); + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +} + +static void MKLDNNQuantizedFlattenForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]); + outputs[1].data().dptr()[0] = inputs[1].data().dptr()[0]; + outputs[2].data().dptr()[0] = inputs[2].data().dptr()[0]; +} + +NNVM_REGISTER_OP(_contrib_quantized_flatten) +.set_attr("FInferStorageType", FlattenStorageType) +.set_attr("FComputeEx", MKLDNNQuantizedFlattenForward) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("TIsMKLDNN", true); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index e3139913332d..e7bfdaa50aa0 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -22,10 +22,14 @@ * \file quantization.cc * \brief */ + +#include #include #include -#include +#include +#include #include +#include #include "quantize_v2-inl.h" namespace mxnet { @@ -116,17 +120,17 @@ bool isRegistered(NodePtr node, const int& dev_type) { fcomputestateful != nullptr || fcomputestateful_ex != nullptr); } -inline NodePtr NeedQuantize( - NodePtr node, const std::unordered_set& excluded_nodes, - const std::unordered_set& excluded_ops, - const int& dev_type) { +inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes, + const std::unordered_set& excluded_ops, + const int& dev_type, + std::unordered_map* quantized_node_map) { std::unordered_map quantized_node; static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); const auto& op = node->op(); - + bool need = false; if (op && quantized_op_map.count(op)) { - bool need = true; + need = true; // If the quantized node is not registered with a computation function, the node // will be excluded automatically. auto q_ptr = quantized_op_map[node->op()]; @@ -154,20 +158,99 @@ inline NodePtr NeedQuantize( } } } - if (need) { - auto n_ptr = quantized_op_map[node->op()]; - auto tmp_node = n_ptr(node->attrs); - if (tmp_node->op()) { - quantized_node[node] = tmp_node; - } else { - quantized_node[node] = nullptr; + auto quantized_node = quantized_op_map[op](node->attrs); + if (!quantized_node->op()) need = false; + if (need) { + quantized_node_map->insert(std::make_pair(node, quantized_node)); + } + if (quantizable_map.count(op)) { + return quantizable_map[op](node->attrs); + } + else { + return QuantizeType::kSupport; } - } else { - quantized_node[node] = nullptr; } } - return quantized_node[node]; + CHECK(!need); + return QuantizeType::kNone; +} + +enum quantize_bit { + kFromInput = 1, + kFromOutput = 2, +}; + +static void MarkQuantizedNodes(const Graph& src, + std::unordered_map& quantized_node_map) { + const auto excluded_nodes = src.GetAttr>("excluded_nodes"); + const auto quantize_mode = src.GetAttr("quantize_mode"); + + std::unordered_map> node_output_map; + std::unordered_set must_quantize_nodes; + std::unordered_map support_quantize_nodes; + // Build node_output_map, must_quantize_nodes and support_quantize_nodes; + DFSVisit(src.outputs, [&](const NodePtr& node) { + auto quantize_type = NeedQuantize(node, excluded_nodes, &quantized_node_map); + if (quantize_type == QuantizeType::kMust) { + must_quantize_nodes.insert(node); + } else if (quantize_type == QuantizeType::kSupport) { + support_quantize_nodes[node] = 0; + } + for (size_t i = 0; i < node->inputs.size(); ++i) { + node_output_map[node->inputs[i].node].push_back(node); + } + }); + + if (quantize_mode == "full") { + return; + } else if (quantize_mode == "smart") { + // Mark quantized nodes from input + std::queue task_queue; + for (const auto& node : must_quantize_nodes) { + task_queue.push(node); + } + while (!task_queue.empty()) { + const auto& node = task_queue.front(); + task_queue.pop(); + for (size_t i = 0; i < node->inputs.size(); ++i) { + const auto& input = node->inputs[i].node; + auto it = support_quantize_nodes.find(input); + if (it != support_quantize_nodes.end()) { + it->second = it->second | kFromInput; + task_queue.push(input); + } + } + } + + // Mark quantized nodes from output + for (const auto& node : must_quantize_nodes) { + task_queue.push(node); + } + while (!task_queue.empty()) { + const auto& node = task_queue.front(); + task_queue.pop(); + const auto& outputs = node_output_map[node]; + for (size_t i = 0; i < outputs.size(); ++i) { + const auto& output = outputs[i]; + auto it = support_quantize_nodes.find(output); + if (it != support_quantize_nodes.end()) { + it->second = it->second | kFromOutput; + task_queue.push(output); + } + } + } + + // Summarize the result + for (const auto& node : support_quantize_nodes) { + CHECK(quantized_node_map.count(node.first)); + if (node.second != (kFromInput | kFromOutput)) { + quantized_node_map.erase(node.first); + } + } + } else { + LOG(FATAL) << "unrecognized quantize mode: " << quantize_mode; + } } Graph QuantizeGraph(Graph &&src) { @@ -181,6 +264,9 @@ Graph QuantizeGraph(Graph &&src) { const auto excluded_ops = src.GetAttr>("excluded_ops"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); + std::unordered_map quantized_node_map; + MarkQuantizedNodes(src, quantized_node_map); + // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key // node. The existing key's value may be updated with the newly created quantize/dequantize op. @@ -190,9 +276,9 @@ Graph QuantizeGraph(Graph &&src) { NodePtr new_node = Node::Create(); // If the currently visited node needs quantization, insert a quantize op node before the // current node and replace the current node with the quantized version in the new graph. - auto tmp_node = NeedQuantize(node, excluded_nodes, excluded_ops, dev_type); - if (tmp_node) { - new_node = tmp_node; + if (quantized_node_map.count(node)) { + LOG(INFO) << node->attrs.name << " is quantized."; + new_node = quantized_node_map[node]; // add data into quantized op input for (size_t i = 0; i < node->inputs.size(); ++i) { @@ -208,7 +294,7 @@ Graph QuantizeGraph(Graph &&src) { if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i)) { new_node->inputs.emplace_back(mirror_entry); - } else if (!NeedQuantize(e.node, excluded_nodes, excluded_ops, dev_type)) { + } else if (!quantized_node_map.count(e.node)) { if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back(mirror_entry_map[e]); } else { @@ -261,7 +347,7 @@ Graph QuantizeGraph(Graph &&src) { // skip non-quantized input continue; } - if (NeedQuantize(e.node, excluded_nodes, excluded_ops, dev_type)) { + if (quantized_node_map.count(e.node)) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1min and 1max output from mirror node (which is @@ -306,6 +392,7 @@ Graph QuantizeGraph(Graph &&src) { // (e.g., a quantized_conv2d node), and insert a dequantize op node in the new graph if there // are any. Otherwise, simply add a copy of the current node's entry to the inputs of // the new_node. + if (!node->is_variable()) LOG(INFO) << node->attrs.name << " is NOT quantized."; *new_node = *node; new_node->inputs.clear(); for (const auto& e : node->inputs) { @@ -313,7 +400,7 @@ Graph QuantizeGraph(Graph &&src) { NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; // if input node is quantized operator, add dequantize node - if (NeedQuantize(e.node, excluded_nodes, excluded_ops, dev_type) && + if (quantized_node_map.count(e.node) && (mirror_node->op() != Op::Get("_contrib_dequantize"))) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that @@ -344,7 +431,7 @@ Graph QuantizeGraph(Graph &&src) { std::vector outputs; for (const auto& e : src.outputs) { - if (NeedQuantize(e.node, excluded_nodes, excluded_ops, dev_type)) { + if (quantized_node_map.count(e.node)) { // Only insert dequantize for those Ops supports quantize and not excluded. NodePtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc index aa3f5ce1ad61..9d774ddf24f1 100644 --- a/src/operator/quantization/quantized_conv.cc +++ b/src/operator/quantization/quantized_conv.cc @@ -180,6 +180,9 @@ and max thresholds representing the threholds for quantizing the float32 output .add_arguments(ConvolutionParam::__FIELDS__()); NNVM_REGISTER_OP(Convolution) +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) .set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { nnvm::NodePtr node = nnvm::Node::Create(); node->attrs.op = Op::Get("_contrib_quantized_conv"); diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index 23790ca78b3d..4c9d9d2f8095 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -361,6 +361,9 @@ and max thresholds representing the threholds for quantizing the float32 output .add_arguments(FullyConnectedParam::__FIELDS__()); NNVM_REGISTER_OP(FullyConnected) +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) .set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { nnvm::NodePtr node = nnvm::Node::Create(); node->attrs.op = Op::Get("_contrib_quantized_fully_connected"); diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index b7776d648e18..a444604ea292 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -792,6 +792,9 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) DefaultSubgraphOpMutableInputs) .set_attr("key_var_num_args", "num_args") .set_attr("FInplaceOption", SgMKLDNNConvInplaceOption) +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) .set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) .set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidQuantizeInput); diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index f345a18c18a6..56baa5da7b9d 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -439,6 +439,9 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected) .set_attr("FMutateInputs", DefaultSubgraphOpMutableInputs) .set_attr("key_var_num_args", "num_args") +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) .set_attr("FQuantizedOp", SgMKLDNNFCQuantizedOp) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }); diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index f7d03a3519cc..97042141659f 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -167,7 +167,8 @@ def check_quantize(sym, data_shape, out_type, name='conv', calib_data=calib_data, calib_layer=None, label_names=None, - num_calib_examples=1) + num_calib_examples=1, + quantize_mode='full') qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) if check_calibration: check_qsym_calibrated(qsym, out_type, name=name) @@ -227,7 +228,8 @@ def check_quantize_whole_model(out_type): calib_data=calib_data, calib_layer=calib_layer, label_names=None, - num_calib_examples=1) + num_calib_examples=1, + quantize_mode='full') qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') check_qsym_forward(qsym, qarg_params, qaux_params, data_shape) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index a1c23fb23208..061c5f762507 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -690,7 +690,8 @@ def test_quantize_params(): params = {} for name in offline_params: params[name] = mx.nd.uniform(shape=(2, 2)) - qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), offline_params=offline_params) + qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + offline_params=offline_params, quantize_mode='full') qparams = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) param_names = params.keys() qparam_names = qparams.keys() @@ -806,7 +807,8 @@ def check_qsym_qdtype(qsym, qdtype): aux_params=aux_params, ctx=mx.current_context(), quantized_dtype=qdtype, - calib_mode='none') + calib_mode='none', + quantize_mode='full') check_params(arg_params, qarg_params, qsym) check_params(aux_params, qaux_params) @@ -820,7 +822,8 @@ def check_qsym_qdtype(qsym, qdtype): quantized_dtype=qdtype, calib_mode='naive', calib_data=calib_data, - num_calib_examples=20) + num_calib_examples=20, + quantize_mode='full') check_params(arg_params, qarg_params, qsym) check_params(aux_params, qaux_params) check_qsym_calibrated(qsym) @@ -966,7 +969,8 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N excluded_op_names=excluded_op_names, ctx=mx.current_context(), quantized_dtype=qdtype, - calib_mode='none') + calib_mode='none', + quantize_model='full') check_params(arg_params, qarg_params, qsym) check_params(aux_params, qaux_params) check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape) @@ -983,7 +987,8 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N quantized_dtype=qdtype, calib_mode='naive', calib_data=calib_data, - num_calib_examples=20) + num_calib_examples=20, + quantize_model='full') check_params(arg_params, qarg_params, qsym) check_params(aux_params, qaux_params) check_qsym_calibrated(qsym) @@ -1050,7 +1055,8 @@ def test_quantize_sym_with_calib(): sym = get_fp32_sym() offline_params = [name for name in sym.list_arguments() if not name.startswith('data') and not name.endswith('label')] - qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), offline_params=offline_params) + qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + offline_params=offline_params, quantize_mode='full') requantize_op_names = ['requantize_conv', 'requantize_fc'] th_dict = {'conv_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0)), 'fc_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0))} From 10543c9ef73b90154d477085d5ce5e2f7416eec6 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 14 Aug 2019 15:14:17 +0800 Subject: [PATCH 02/14] Add calibrate op Change-Id: I4c82f64dbef501d2560f7ffee93991119a66a5ee --- example/quantization/imagenet_gen_qsym.py | 17 +- .../quantization/imagenet_gen_qsym_mkldnn.py | 3 +- example/ssd/quantization.py | 6 +- include/mxnet/c_api.h | 5 +- include/mxnet/c_api_error.h | 12 - include/mxnet/op_attr_types.h | 14 + python/mxnet/contrib/quantization.py | 263 +++++++----------- src/c_api/c_api_symbolic.cc | 14 +- src/operator/quantization/calibrate-inl.h | 48 ++++ src/operator/quantization/calibrate.cc | 211 ++++++++++++++ .../mkldnn/mkldnn_quantized_elemwise_add.cc | 4 +- .../quantization/quantization_utils.h | 4 +- .../quantization/quantize_graph_pass.cc | 182 ++++++------ src/operator/quantization/quantize_v2.cc | 3 + .../quantization/quantized_batch_norm.cc | 3 + src/operator/quantization/requantize.cc | 3 + tests/python/mkl/test_subgraph.py | 3 - .../python/quantization/test_quantization.py | 4 + 18 files changed, 501 insertions(+), 298 deletions(-) create mode 100644 src/operator/quantization/calibrate-inl.h create mode 100644 src/operator/quantization/calibrate.cc diff --git a/example/quantization/imagenet_gen_qsym.py b/example/quantization/imagenet_gen_qsym.py index 27a1d7a45a68..ec4c809a79d0 100644 --- a/example/quantization/imagenet_gen_qsym.py +++ b/example/quantization/imagenet_gen_qsym.py @@ -141,23 +141,12 @@ def save_params(fname, arg_params, aux_params, logger=None): excluded_op_names = [] if args.model == 'imagenet1k-resnet-152': rgb_mean = '0,0,0' - if args.ctx == 'gpu': - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 - or name.find('sc') != -1 - or name.find('fc') != -1) - else: - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 - or name.find('sc') != -1) - excluded_sym_names += ['flatten0', 'fc1'] + excluded_sym_names += ['flatten0', 'fc1'] if exclude_first_conv: excluded_sym_names += ['conv0'] elif args.model == 'imagenet1k-inception-bn': rgb_mean = '123.68,116.779,103.939' - if args.ctx == 'gpu': - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 - or name.find('fc') != -1) - else: - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1) + if args.ctx == 'cpu': excluded_sym_names += ['flatten', 'fc1'] if exclude_first_conv: excluded_sym_names += ['conv_1'] @@ -203,7 +192,7 @@ def save_params(fname, arg_params, aux_params, logger=None): excluded_op_names=excluded_op_names, calib_mode=calib_mode, calib_data=data, num_calib_examples=num_calib_batches * batch_size, - calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + quantized_dtype=args.quantized_dtype, logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 9b69c8088cc0..3c6931f81995 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -193,7 +193,6 @@ def save_params(fname, arg_params, aux_params, logger=None): # get image shape image_shape = args.image_shape - calib_layer = lambda name: name.endswith('_output') or name == "data" exclude_first_conv = args.exclude_first_conv if args.quantized_dtype == "uint8": logger.info('quantized dtype is set to uint8, will exclude first conv.') @@ -288,7 +287,7 @@ def save_params(fname, arg_params, aux_params, logger=None): ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, calib_data=data, num_calib_examples=num_calib_batches * batch_size, - calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + quantized_dtype=args.quantized_dtype, label_names=(label_name,), logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py index b92aadcf47ac..9801a7dc0076 100644 --- a/example/ssd/quantization.py +++ b/example/ssd/quantization.py @@ -115,10 +115,6 @@ def save_params(fname, arg_params, aux_params, logger=None): # get image shape image_shape = '3,300,300' - def calib_layer(name): return not (name.endswith('_data') or - name.endswith('_weight') or - name.endswith('_bias') or - name.endswith('_workspace')) # Quantization layer configs exclude_first_conv = args.exclude_first_conv excluded_sym_names = [] @@ -154,7 +150,7 @@ def calib_layer(name): return not (name.endswith('_data') or ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, calib_data=eval_iter, num_calib_examples=num_calib_batches * batch_size, - calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + quantized_dtype=args.quantized_dtype, label_names=(label_name,), logger=logger) sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300') param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index dfe0a9f98b25..e0ed316a1bdc 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1908,6 +1908,8 @@ MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym, * \param quantized_dtype the quantized destination type for input data * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could * \param quantize_mode quantize mode to be used in quantize pass + * \param out_num_calib_names return the number of nodes to be calibrated + * \param out_calib_names return the node names to be calibrated */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const int* dev_type, @@ -1917,7 +1919,8 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_ha const char **excluded_op_names, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, const bool calib_quantize, - const char *quantize_mode); + const char *quantize_mode, mx_uint* out_num_calib_names, + const char ***out_calib_names); /*! * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting diff --git a/include/mxnet/c_api_error.h b/include/mxnet/c_api_error.h index e76a2c99f8d3..b10bcc1d0983 100644 --- a/include/mxnet/c_api_error.h +++ b/include/mxnet/c_api_error.h @@ -32,23 +32,11 @@ * The finally clause contains procedure to cleanup states when an error happens. */ #define MX_API_BEGIN() \ - try { \ on_enter_api(__FUNCTION__); #define MX_API_END() \ - } \ - catch (const std::exception &_except_) { \ - on_exit_api(); \ - return MXAPIHandleException(_except_); \ - } \ on_exit_api(); \ return 0; // NOLINT(*) #define MX_API_END_HANDLE_ERROR(Finalize) \ - } \ - catch (const std::exception &_except_) { \ - Finalize; \ - on_exit_api(); \ - return MXAPIHandleException(_except_); \ - } \ on_exit_api(); \ return 0; // NOLINT(*) /*! diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 9f4f5d3eadd2..75d843c98bd2 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -335,6 +335,20 @@ using FNeedRequantize = std::function; using FAvoidQuantizeInput = std::function; +/*! + * \brief Register a function to determine if the input of a quantized operator + * needs to be calibrated. This is usually used for the quantized operators + * which need calibration on its input. + */ +using FNeedCalibrateInput = std::function (const NodeAttrs& attrs)>; + +/*! + * \brief Register a function to determine if the output of a quantized operator + * needs to be calibrated. This is usually used for the quantized operators + * which need calibration on its output. + */ +using FNeedCalibrateOutput = std::function (const NodeAttrs& attrs)>; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index f324bb3f82be..e59069f642b1 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -23,6 +23,7 @@ except ImportError: stats = None +import sys import ctypes import logging import os @@ -132,6 +133,8 @@ def _quantize_symbol(sym, excluded_symbols=None, excluded_operators=None, offline.append(c_str(k)) out = SymbolHandle() + size = mx_uint() + calib_str = ctypes.POINTER(ctypes.c_char_p)() check_call(_LIB.MXQuantizeSymbol(sym.handle, ctypes.byref(out), ctypes.byref(ctypes.c_int(ctx.device_typeid)), @@ -143,47 +146,72 @@ def _quantize_symbol(sym, excluded_symbols=None, excluded_operators=None, c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), ctypes.c_bool(True), - c_str(quantize_mode))) - return Symbol(out) - - -class _LayerOutputCollector(object): - """Saves layer output NDArray in a dict with layer names as keys and lists of NDArrays as - values. The collected NDArrays will be used for calculating the optimal thresholds for + c_str(quantize_mode), + ctypes.byref(size), + ctypes.byref(calib_str))) + calib_layer = [] + calib_layer = [py_str(calib_str[i]) for i in range(size.value)] + return Symbol(out), calib_layer + +def combine_histogram(old_hist, arr, new_min, new_max, new_th): + (old_hist, old_hist_edges, old_min, old_max, old_th) = old_hist + if new_th <= old_th: + hist, _ = np.histogram(arr, bins=len(old_hist), range=(-old_th, old_th)) + return (old_hist + hist, old_hist_edges, min(old_min, new_min), max(old_max, new_max), old_th) + else: + # Need to generate new histogram with new_th + old_num_bins = len(old_hist) + old_step = 2 * old_th / old_num_bins + half_increased_bins = int((new_th - old_th) // old_step + 1) + new_num_bins = half_increased_bins * 2 + old_num_bins + new_th = half_increased_bins * old_step + old_th + hist, hist_edges = np.histogram(arr, bins=new_num_bins, range=(-new_th, new_th)) + hist[half_increased_bins:new_num_bins - half_increased_bins] += old_hist + return (hist, hist_edges, min(old_min, new_min), max(old_max, new_max), new_th) + +class _LayerHistogramCollector(object): + """Saves layer histogram in a dict with layer names as keys and lists of NDArrays as + values. The collected histogram will be used for calculating the optimal thresholds for quantization using KL divergence. """ - def __init__(self, include_layer=None, logger=None): - self.nd_dict = {} + def __init__(self, num_bins=8001, include_layer=None, logger=None): + self.hist_dict = {} + self.num_bins = num_bins self.include_layer = include_layer self.logger = logger def collect(self, name, arr): """Callback function for collecting layer output NDArrays.""" name = py_str(name) - if self.include_layer is not None and not self.include_layer(name): + if name not in self.include_layer: return handle = ctypes.cast(arr, NDArrayHandle) - arr = NDArray(handle, writable=False).copyto(cpu()) + arr = NDArray(handle, writable=False).copyto(cpu()).asnumpy() if self.logger is not None: - self.logger.info("Collecting layer %s output of shape %s" % (name, arr.shape)) - if name in self.nd_dict: - self.nd_dict[name].append(arr) + self.logger.info("Collecting layer %s histogram of shape %s" % (name, arr.shape)) + min_range = np.min(arr) + max_range = np.max(arr) + th = max(abs(min_range), abs(max_range)) + if name in self.hist_dict: + self.hist_dict[name] = combine_histogram(self.hist_dict[name], arr, min_range, max_range, th) else: - self.nd_dict[name] = [arr] + hist, hist_edges = np.histogram(arr, bins=self.num_bins, range=(-th, th)) + self.hist_dict[name] = (hist, hist_edges, min_range, max_range, th) class _LayerOutputMinMaxCollector(object): """Saves layer output min and max values in a dict with layer names as keys. The collected min and max values will be directly used as thresholds for quantization. """ - def __init__(self, include_layer=None, logger=None): + def __init__(self, quantized_dtype, include_layer=None, logger=None): self.min_max_dict = {} + self.quantized_dtype = quantized_dtype self.include_layer = include_layer self.logger = logger def collect(self, name, arr): """Callback function for collecting min and max values from an NDArray.""" name = py_str(name) - if self.include_layer is not None and not self.include_layer(name): + if name not in self.include_layer: return handle = ctypes.cast(arr, NDArrayHandle) arr = NDArray(handle, writable=False) @@ -243,21 +271,23 @@ def _collect_layer_statistics(mod, data, collector, max_num_examples=None, logge return num_examples -def _collect_layer_output_min_max(mod, data, include_layer=None, +def _collect_layer_output_min_max(mod, data, quantized_dtype, include_layer=None, max_num_examples=None, logger=None): """Collect min and max values from layer outputs and save them in a dictionary mapped by layer names. """ - collector = _LayerOutputMinMaxCollector(include_layer=include_layer, logger=logger) + collector = _LayerOutputMinMaxCollector(quantized_dtype=quantized_dtype, + include_layer=include_layer, logger=logger) num_examples = _collect_layer_statistics(mod, data, collector, max_num_examples, logger) return collector.min_max_dict, num_examples -def _collect_layer_outputs(mod, data, include_layer=None, max_num_examples=None, logger=None): +def _collect_layer_histogram(mod, data, quantized_dtype, include_layer=None, + max_num_examples=None, logger=None): """Collect layer outputs and save them in a dictionary mapped by layer names.""" - collector = _LayerOutputCollector(include_layer=include_layer, logger=logger) + collector = _LayerHistogramCollector(include_layer=include_layer, logger=logger) num_examples = _collect_layer_statistics(mod, data, collector, max_num_examples, logger) - return collector.nd_dict, num_examples + return collector.hist_dict, num_examples def _smooth_distribution(p, eps=0.0001): @@ -281,126 +311,55 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=line-too-long -def _get_optimal_threshold(arr, quantized_dtype, num_bins=8001, num_quantized_bins=255): +def _get_optimal_threshold(hist_data, quantized_dtype, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. The reference distribution is `q`, and the candidate distribution is `p`. `q` is a truncated version of the original distribution. Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ - if isinstance(arr, NDArray): - arr = arr.asnumpy() - elif isinstance(arr, list): - assert len(arr) != 0 - for i, nd in enumerate(arr): - if isinstance(nd, NDArray): - arr[i] = nd.asnumpy() - elif not isinstance(nd, np.ndarray): - raise TypeError('get_optimal_threshold only supports input type of NDArray,' - ' list of np.ndarrays or NDArrays, and np.ndarray,' - ' while received type=%s' % (str(type(nd)))) - arr = np.concatenate(arr) - elif not isinstance(arr, np.ndarray): - raise TypeError('get_optimal_threshold only supports input type of NDArray,' - ' list of NDArrays and np.ndarray,' - ' while received type=%s' % (str(type(arr)))) - min_val = np.min(arr) - max_val = np.max(arr) - th = max(abs(min_val), abs(max_val)) - + (hist, hist_edges, min_val, max_val, _) = hist_data + num_bins = len(hist) + assert (num_bins % 2 == 1) if min_val >= 0 and quantized_dtype in ['auto', 'uint8']: # We need to move negative bins to positive bins to fit uint8 range. num_quantized_bins = num_quantized_bins * 2 + 1 - - hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) - zero_bin_idx = num_bins // 2 - num_half_quantized_bins = num_quantized_bins // 2 - - thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) - divergence = np.zeros_like(thresholds) - quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32) - # i means the number of bins on half axis excluding the zero bin. - for i in range(num_quantized_bins // 2, - num_bins // 2 + 1): - p_bin_idx_start = zero_bin_idx - i - p_bin_idx_stop = zero_bin_idx + i + 1 - thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop] - sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop] - - # generate reference distribution p - p = sliced_nd_hist.copy() - assert p.size % 2 == 1 - assert p.size >= num_quantized_bins - # put left outlier count in p[0] - left_outlier_count = np.sum(hist[0:p_bin_idx_start]) - p[0] += left_outlier_count - # put right outlier count in p[-1] - right_outlier_count = np.sum(hist[p_bin_idx_stop:]) - p[-1] += right_outlier_count - # is_nonzeros[k] indicates whether hist[k] is nonzero - is_nonzeros = (p != 0).astype(np.int32) - - # calculate how many bins should be merged to generate quantized distribution q - num_merged_bins = sliced_nd_hist.size // num_quantized_bins - # merge hist into num_quantized_bins bins - for j in range(num_quantized_bins): - start = j * num_merged_bins - stop = start + num_merged_bins - quantized_bins[j] = sliced_nd_hist[start:stop].sum() - quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() - # expand quantized_bins into p.size bins - q = np.zeros(sliced_nd_hist.size, dtype=np.float32) - for j in range(num_quantized_bins): - start = j * num_merged_bins - if j == num_quantized_bins - 1: - stop = len(is_nonzeros) - else: - stop = start + num_merged_bins - norm = is_nonzeros[start:stop].sum() - if norm != 0: - q[start:stop] = float(quantized_bins[j]) / float(norm) - q[p == 0] = 0 - p = _smooth_distribution(p) - # There is a chance that q is an invalid probability distribution. - try: - q = _smooth_distribution(q) - except ValueError: - divergence[i - num_half_quantized_bins] = float("inf") - divergence[i - num_half_quantized_bins] = stats.entropy(p, q) - - min_divergence_idx = np.argmin(divergence) - min_divergence = divergence[min_divergence_idx] - opt_th = thresholds[min_divergence_idx] - return min_val, max_val, min_divergence, opt_th + hist = ndarray.array(hist) + hist_edges = ndarray.array(hist_edges) + threshold, divergence = ndarray.contrib.calibrate_entropy(hist=hist, + hist_edges=hist_edges, + num_quantized_bins=num_quantized_bins) + threshold = threshold.asnumpy() + divergence = divergence.asnumpy() + return min_val, max_val, threshold, divergence # pylint: enable=line-too-long - -def _get_optimal_thresholds(nd_dict, quantized_dtype, num_bins=8001, num_quantized_bins=255, logger=None): +def _get_optimal_thresholds(hist_dict, quantized_dtype, num_quantized_bins=255, logger=None): """Given a ndarray dict, find the optimal threshold for quantizing each value of the key.""" if stats is None: raise ImportError('scipy.stats is required for running entropy mode of calculating' ' the optimal thresholds for quantizing FP32 ndarrays into int8.' ' Please check if the scipy python bindings are installed.') - assert isinstance(nd_dict, dict) + assert isinstance(hist_dict, dict) if logger is not None: logger.info('Calculating optimal thresholds for quantization using KL divergence' - ' with num_bins=%d and num_quantized_bins=%d' % (num_bins, num_quantized_bins)) + ' with num_quantized_bins=%d' % num_quantized_bins) th_dict = {} - # copy nd_dict keys since the keys() only returns a view in python3 - layer_names = list(nd_dict.keys()) + # copy hist_dict keys since the keys() only returns a view in python3 + layer_names = list(hist_dict.keys()) for name in layer_names: - assert name in nd_dict - min_val, max_val, min_divergence, opt_th = \ - _get_optimal_threshold(nd_dict[name], quantized_dtype, num_bins=num_bins, + assert name in hist_dict + min_val, max_val, th, divergence = \ + _get_optimal_threshold(hist_dict[name], quantized_dtype, num_quantized_bins=num_quantized_bins) - del nd_dict[name] # release the memory of ndarray - if min_val < 0: - th_dict[name] = (-opt_th, opt_th) + if min_val >= 0 and quantized_dtype in ['auto', 'uint8']: + th_dict[name] = (0, th) else: - th_dict[name] = (0, opt_th) + th_dict[name] = (-th, th) + del hist_dict[name] # release the memory if logger is not None: - logger.info('layer=%s, min_val=%f, max_val=%f, min_divergence=%f, optimal_threshold=%f' - % (name, min_val, max_val, min_divergence, opt_th)) + logger.info('layer=%s, min_val=%f, max_val=%f, th=%f, divergence=%f' + % (name, min_val, max_val, th, divergence)) return th_dict @@ -483,7 +442,7 @@ def _as_data_iter(calib_data): def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', - calib_data=None, num_calib_examples=None, calib_layer=None, + calib_data=None, num_calib_examples=None, quantized_dtype='int8', quantize_mode='smart', logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run @@ -532,11 +491,6 @@ def quantize_model(sym, arg_params, aux_params, num_calib_examples : int or None The maximum number of examples that user would like to use for calibration. If not provided, the whole calibration dataset will be used. - calib_layer : function - Given a layer's output name in string, return True or False for deciding whether to - calibrate this layer. If yes, the statistics of the layer's output will be collected; - otherwise, no information of the layer's output will be collected. If not provided, - all the layers' outputs that need requantization will be collected. quantized_dtype : str The quantized destination type for input data. Currently support 'int8' , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. @@ -570,10 +524,11 @@ def quantize_model(sym, arg_params, aux_params, if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8`, `uint8` or `auto`' % quantized_dtype) - qsym = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, - excluded_operators=excluded_op_names, - offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype, quantize_mode=quantize_mode) + qsym, calib_layer = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, + excluded_operators=excluded_op_names, + offline_params=list( + arg_params.keys()), + quantized_dtype=quantized_dtype, quantize_mode=quantize_mode) th_dict = {} if calib_mode is not None and calib_mode != 'none': @@ -593,23 +548,22 @@ def quantize_model(sym, arg_params, aux_params, mod.bind(for_training=False, data_shapes=calib_data.provide_data) mod.set_params(arg_params, aux_params) if calib_mode == 'entropy': - nd_dict, num_examples = _collect_layer_outputs(mod, calib_data, - include_layer=calib_layer, - max_num_examples=num_calib_examples, - logger=logger) + hist_dict, num_examples = _collect_layer_histogram(mod, calib_data, quantized_dtype, + include_layer=calib_layer, + max_num_examples=num_calib_examples, + logger=logger) logger.info('Collected layer outputs from FP32 model using %d examples' % num_examples) logger.info('Calculating optimal thresholds for quantization') - th_dict = _get_optimal_thresholds(nd_dict, quantized_dtype, logger=logger) + th_dict = _get_optimal_thresholds(hist_dict, quantized_dtype, logger=logger) elif calib_mode == 'naive': th_dict, num_examples = _collect_layer_output_min_max( - mod, calib_data, include_layer=calib_layer, max_num_examples=num_calib_examples, + mod, calib_data, quantized_dtype, include_layer=calib_layer, max_num_examples=num_calib_examples, logger=logger) logger.info('Collected layer output min/max values from FP32 model using %d examples' % num_examples) else: raise ValueError('unknown calibration mode %s received,' ' expected `none`, `naive`, or `entropy`' % calib_mode) - logger.info('Calibrating quantized symbol') qsym = _calibrate_quantized_sym(qsym, th_dict) logger.info('Quantizing parameters') @@ -621,7 +575,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, - calib_layer=None, quantized_dtype='int8', logger=logging): + quantized_dtype='int8', logger=logging): """User-level API for generating a fusion + quantized model from a FP32 model w/ or w/o calibration with Intel MKL-DNN. The backend quantized operators are only enabled for Linux systems. Please do not run @@ -648,7 +602,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params, ctx=ctx, excluded_sym_names=excluded_sym_names, excluded_op_names=excluded_op_names, calib_mode=calib_mode, calib_data=calib_data, - num_calib_examples=num_calib_examples, calib_layer=calib_layer, + num_calib_examples=num_calib_examples, quantized_dtype=quantized_dtype, logger=logger) qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') @@ -657,7 +611,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params, def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', - calib_layer=None, quantized_dtype='int8', logger=logging): + quantized_dtype='int8', logger=logging): """User-level API for generating a quantized model from a FP32 model w/o calibration and a collector for naive or entropy calibration. The backend quantized operators are only enabled for Linux systems. Please do not run @@ -688,11 +642,6 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), If calib_mode='entropy' (default mode), the thresholds for quantization will be derived such that the KL divergence between the distributions of FP32 layer outputs and quantized layer outputs is minimized based upon the calibration dataset. - calib_layer : function - Given a layer's output name in string, return True or False for deciding whether to - calibrate this layer. If yes, the statistics of the layer's output will be collected; - otherwise, no information of the layer's output will be collected. If not provided, - all the layers' outputs that need requantization will be collected. quantized_dtype : str The quantized destination type for input data. Currently support 'int8' , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. @@ -717,16 +666,17 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8`, `uint8` or `auto`' % quantized_dtype) - qsym = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, - excluded_operators=excluded_op_names, - offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype) + qsym, calib_layer = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, + excluded_operators=excluded_op_names, + offline_params=list( + arg_params.keys()), + quantized_dtype=quantized_dtype) th_dict = {} collector = None if calib_mode is not None and calib_mode != 'none': if calib_mode == 'entropy': - collector = _LayerOutputCollector( + collector = _LayerHistogramCollector( include_layer=calib_layer, logger=logger) logger.info( 'Create a layer output collector for entropy calibration.') @@ -771,11 +721,6 @@ def calib_graph(qsym, arg_params, aux_params, collector, If calib_mode='entropy' (default mode), the thresholds for quantization will be derived such that the KL divergence between the distributions of FP32 layer outputs and quantized layer outputs is minimized based upon the calibration dataset. - calib_layer : function - Given a layer's output name in string, return True or False for deciding whether to - calibrate this layer. If yes, the statistics of the layer's output will be collected; - otherwise, no information of the layer's output will be collected. If not provided, - all the layers' outputs that need requantization will be collected. quantized_dtype : str The quantized destination type for input data. Currently support 'int8' , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. @@ -793,13 +738,12 @@ def calib_graph(qsym, arg_params, aux_params, collector, if calib_mode == 'entropy': logger.info('Calculating optimal thresholds for quantization') th_dict = _get_optimal_thresholds( - collector.nd_dict, quantized_dtype, logger=logger) + collector.hist_dict, quantized_dtype, logger=logger) elif calib_mode == 'naive': th_dict = collector.min_max_dict else: raise ValueError('unknown calibration mode %s received,' ' expected `none`, `naive`, or `entropy`' % calib_mode) - logger.info('Calibrating quantized symbol') qsym = _calibrate_quantized_sym(qsym, th_dict) else: raise ValueError('please set calibration mode to naive or entropy.') @@ -846,11 +790,6 @@ def quantize_net(network, quantized_dtype='auto', If calib_mode='entropy' (default mode), the thresholds for quantization will be derived such that the KL divergence between the distributions of FP32 layer outputs and quantized layer outputs is minimized based upon the calibration dataset. - calib_layer : function - Given a layer's output name in string, return True or False for deciding whether to - calibrate this layer. If yes, the statistics of the layer's output will be collected; - otherwise, no information of the layer's output will be collected. If not provided, - all the layers' outputs that need requantization will be collected. num_calib_examples : int or None The maximum number of examples that user would like to use for calibration. If not provided, the whole calibration dataset will be used. @@ -931,7 +870,7 @@ def __exit__(self, exc_type, exc_value, traceback): qsym, qarg_params, aux_params, collector = quantize_graph( sym=symnet, arg_params=args, aux_params=auxs, ctx=ctx, excluded_sym_names=exclude_layers, excluded_op_names=exclude_operators, - calib_mode=calib_mode, calib_layer=None, quantized_dtype=quantized_dtype, logger=logger) + calib_mode=calib_mode, quantized_dtype=quantized_dtype, logger=logger) if calib_mode is not None and calib_mode != 'none': if not isinstance(ctx, Context): diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index c8046e3c216a..fecab85fe6d6 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -889,7 +889,9 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const char **offline_params, const char *quantized_dtype, const bool calib_quantize, - const char *quantize_mode) { + const char *quantize_mode, + mx_uint* out_num_calib_names, + const char ***out_calib_names) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -916,6 +918,16 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, g.attrs["target_ctx"] = std::make_shared(target_dev); g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); g = ApplyPass(std::move(g), "QuantizeGraph"); + const auto& calib_nodes =g.GetAttr>("calib_nodes"); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + ret->ret_vec_str = std::move(calib_nodes); + *out_num_calib_names = ret->ret_vec_str.size(); + ret->ret_vec_charp.clear(); + ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); + for (const auto &str : ret->ret_vec_str) { + ret->ret_vec_charp.push_back(str.c_str()); + } + *out_calib_names = dmlc::BeginPtr(ret->ret_vec_charp); s->outputs = g.outputs; *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); diff --git a/src/operator/quantization/calibrate-inl.h b/src/operator/quantization/calibrate-inl.h new file mode 100644 index 000000000000..ac7d04403be1 --- /dev/null +++ b/src/operator/quantization/calibrate-inl.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file calibraite-inl.h + * \brief Implementation of calibrate operator + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_CALIBRATE_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_CALIBRATE_INL_H_ + +#include +#include +#include "../mxnet_op.h" +#include "./quantization_utils.h" + +namespace mxnet { +namespace op { + +struct CalibrateEntropyParam : public dmlc::Parameter { + int num_quantized_bins; + DMLC_DECLARE_PARAMETER(CalibrateEntropyParam) { + DMLC_DECLARE_FIELD(num_quantized_bins) + .set_default(255) + .describe( + "The number of quantized bins."); + } +}; + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_QUANTIZATION_CALIBRATE_INL_H_ diff --git a/src/operator/quantization/calibrate.cc b/src/operator/quantization/calibrate.cc new file mode 100644 index 000000000000..4255bf907f4c --- /dev/null +++ b/src/operator/quantization/calibrate.cc @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file calibrate.cc + * \brief + */ + +#include "./calibrate-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(CalibrateEntropyParam); + +// Given a discrete distribution (may have not been normalized to 1), +// smooth it by replacing zeros with eps multiplied by a scaling factor and taking the +// corresponding amount off the non-zero values. +std::vector SmoothDistribution(const std::vector& p, const float eps = 0.0001) { + std::vector is_zeros(p.size()); + std::vector is_nonzeros(p.size()); + { + auto it = p.begin(); + std::generate(is_zeros.begin(), is_zeros.end(), + [&it]() { return static_cast(*(it++) == 0.f); }); + } + { + auto it = p.begin(); + std::generate(is_nonzeros.begin(), is_nonzeros.end(), + [&it]() { return static_cast(*(it++) != 0.f); }); + } + + size_t n_zeros = std::accumulate(is_zeros.begin(), is_zeros.end(), 0); + size_t n_nonzeros = p.size() - n_zeros; + if (!n_nonzeros) { + // The discrete probability distribution is malformed. All entries are 0. + return std::vector(); + } + float eps1 = eps * static_cast(n_zeros) / static_cast(n_nonzeros); + if (eps1 >= 1.0) return std::vector(); + auto ret = p; + for (size_t i = 0; i < p.size(); i++) { + ret[i] += eps * is_zeros[i] - eps1 * is_nonzeros[i]; + } + return ret; +} +static float ComputeEntropy(std::vector& p, std::vector& q) { + CHECK_EQ(p.size(), q.size()); + float p_sum = std::accumulate(p.begin(), p.end(), 0.f); + float q_sum = std::accumulate(q.begin(), q.end(), 0.f); + for (auto& it : p) { + it = it / p_sum; + } + + for (auto& it : q) { + it = it / q_sum; + } + float ret = 0; + for (size_t i = 0; i < p.size(); i++) { + CHECK(p[i] > 0 && q[i] > 0); + if (p[i] && q[i]) ret += p[i] * std::log(p[i] / q[i]); + } + return ret; +} + +void CalibrateComputeCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, const std::vector& req, + const std::vector& outputs) { + const auto& param = nnvm::get(attrs.parsed); + const auto& hist = inputs[0]; + const auto& hist_ptr = hist.dptr(); + const auto& hist_edges = inputs[1]; + const auto& hist_edges_ptr = hist_edges.dptr(); + float* const out_threshold = outputs[0].dptr(); + float* const out_divergence = outputs[1].dptr(); + const auto num_bins = hist.Size(); + CHECK_EQ(num_bins + 1, hist_edges.Size()); + int num_quantized_bins = param.num_quantized_bins; + + const int zero_bin_idx = num_bins / 2; + const int num_half_quantized_bins = num_quantized_bins / 2; + std::vector thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f); + std::vector divergence(thresholds.size(), 0.f); + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (index_t i = num_quantized_bins / 2; i < static_cast(num_bins / 2 + 1); i++) { + const size_t p_bin_idx_start = zero_bin_idx - i; + const size_t p_bin_idx_stop = zero_bin_idx + i + 1; + thresholds[i - num_half_quantized_bins] = hist_edges_ptr[p_bin_idx_stop]; + + std::vector sliced_nd_hist(p_bin_idx_stop - p_bin_idx_start); + std::vector p(p_bin_idx_stop - p_bin_idx_start); + p[0] = 0; + p.back() = 0; + for (size_t j = 0; j < num_bins; j++) { + if (j <= p_bin_idx_start) { + p[0] += hist_ptr[j]; + } else if (j >= p_bin_idx_stop) { + p.back() += hist_ptr[j]; + } else { + sliced_nd_hist[j - p_bin_idx_start] = hist_ptr[j]; + p[j - p_bin_idx_start] = hist_ptr[j]; + } + } + // calculate how many bins should be merged to generate quantized distribution q + const float num_merged_bins = sliced_nd_hist.size() / num_quantized_bins; + // merge hist into num_quantized_bins bins + std::vector quantized_bins(num_quantized_bins, 0); + for (index_t j = 0; j < num_quantized_bins; j++) { + const int start = std::round(j * num_merged_bins); + const int stop = std::round((j + 1) * num_merged_bins); + quantized_bins[j] = + std::accumulate(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, 0); + } + quantized_bins.back() += std::accumulate( + sliced_nd_hist.begin() + static_cast(std::round(num_quantized_bins * num_merged_bins)), + sliced_nd_hist.end(), 0); + // expand quantized_bins into p.size bins + std::vector q(sliced_nd_hist.size(), 0); + for (index_t j = 0; j < num_quantized_bins; j++) { + const int start = std::round(j * num_merged_bins); + const int stop = std::round((j + 1) * num_merged_bins); + int norm = std::count_if(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, + [](size_t i) { return i != 0; }); + if (norm) { + for (index_t k = start; k < stop; k++) { + if (p[k]) q[k] = quantized_bins[j] / norm; + } + } + } + p = SmoothDistribution(p); + q = SmoothDistribution(q); + + if (!q.size()) { + divergence[i - num_half_quantized_bins] = std::numeric_limits::infinity(); + } else { + divergence[i - num_half_quantized_bins] = ComputeEntropy(p, q); + } + } + + size_t min_divergence_idx = 0; + float min_divergence = mshadow::red::limits::MaxValue(); + for (size_t i = 0; i < divergence.size(); i++) { + if (divergence[i] < min_divergence) { + min_divergence = divergence[i]; + min_divergence_idx = i; + } + } + *out_divergence = min_divergence; + *out_threshold = thresholds[min_divergence_idx]; +} + +static inline bool CalibrateShape(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1)); + return (!shape_is_none(in_attrs->at(0))) && (!shape_is_none(in_attrs->at(1))); +} + +static inline bool CalibrateType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + CHECK(in_attrs->at(0) == mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + return true; +} + +NNVM_REGISTER_OP(_contrib_calibrate_entropy) +.describe(R"code(Provide calibrated min/max for input histogram. + +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"hist", "hist_edges"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"threshold", "divergence"}; +}) +.set_attr("FInferShape", CalibrateShape) +.set_attr("FInferType", CalibrateType) +.set_attr("FCompute", CalibrateComputeCPU) +.add_argument("hist", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") +.add_argument("hist_edges", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") +.add_arguments(CalibrateEntropyParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc index 05da99207651..2be6b2baca63 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc @@ -60,7 +60,7 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons auto dataB_mem = in_data[quantized_elemwise_add_enum::kDataB].GetMKLDNNData(); const bool is_dataA_int8 = (in_data[quantized_elemwise_add_enum::kDataA].dtype() == mshadow::kInt8); - const size_t dataA_range = is_dataA_int8 ? kInt8Range : kUint8Range; + const float dataA_range = is_dataA_int8 ? kInt8Range : kUint8Range; const float A_scale = GetScale(in_data[quantized_elemwise_add_enum::kDataA], dataA_min, @@ -72,7 +72,7 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons mkldnn::memory *rescaled_mem; // output default set as int32 - size_t output_data_range = kInt32Range; + float output_data_range = kInt32Range; auto output_data_type = mkldnn::memory::s32; // dataA && dataB are uint8 if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kInt8) { diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h index e7f7ccdf13b7..857841d0db1b 100644 --- a/src/operator/quantization/quantization_utils.h +++ b/src/operator/quantization/quantization_utils.h @@ -32,8 +32,8 @@ namespace mxnet { namespace op { -static const size_t kUint8Range = 255; -static const size_t kInt8Range = 127; +static const float kUint8Range = 255.5; +static const float kInt8Range = 127.5; static const size_t kInt32Range = 0x7fffffff; template diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index e7bfdaa50aa0..cf2d4ba16146 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -41,7 +41,7 @@ using nnvm::NodePtr; using nnvm::NodeEntry; using nnvm::Graph; -inline size_t GetNumOutputs(NodePtr node) { +static inline size_t GetNumOutputs(NodePtr node) { // Get NumOutputs, check if current node has NumVisibleOutputs function, if yes, return // num_visible_outputs size_t num_outputs = node->num_outputs(); @@ -54,6 +54,12 @@ inline size_t GetNumOutputs(NodePtr node) { return num_outputs; } +static inline std::string GetOutputName(const NodeEntry& e) { + nnvm::Symbol sym; + sym.outputs.push_back(e); + return sym.ListOutputNames()[0]; +} + NodePtr CreateNode(std::string op_name, std::string node_name) { NodePtr node = Node::Create(); node->attrs.name = node_name; @@ -271,13 +277,14 @@ Graph QuantizeGraph(Graph &&src) { // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key // node. The existing key's value may be updated with the newly created quantize/dequantize op. std::unordered_map mirror_map; + std::unordered_map reverse_mirror_map; nnvm::NodeEntryMap mirror_entry_map; DFSVisit(src.outputs, [&](const NodePtr& node) { NodePtr new_node = Node::Create(); // If the currently visited node needs quantization, insert a quantize op node before the // current node and replace the current node with the quantized version in the new graph. if (quantized_node_map.count(node)) { - LOG(INFO) << node->attrs.name << " is quantized."; + std::cout << node->attrs.name << " is quantized." << std::endl; new_node = quantized_node_map[node]; // add data into quantized op input @@ -392,7 +399,7 @@ Graph QuantizeGraph(Graph &&src) { // (e.g., a quantized_conv2d node), and insert a dequantize op node in the new graph if there // are any. Otherwise, simply add a copy of the current node's entry to the inputs of // the new_node. - if (!node->is_variable()) LOG(INFO) << node->attrs.name << " is NOT quantized."; + if (!node->is_variable()) std::cout << node->attrs.name << " is NOT quantized." << std::endl; *new_node = *node; new_node->inputs.clear(); for (const auto& e : node->inputs) { @@ -417,7 +424,8 @@ Graph QuantizeGraph(Graph &&src) { dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); new_node->inputs.emplace_back(dequantize_node, 0, 0); - mirror_map[e.node.get()] = std::move(dequantize_node); + mirror_map[e.node.get()] = dequantize_node; + reverse_mirror_map[dequantize_node] = e.node; } else if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back( mirror_entry_map[e].node->inputs[0].node, e.index, e.version); @@ -426,7 +434,8 @@ Graph QuantizeGraph(Graph &&src) { } } } - mirror_map[node.get()] = std::move(new_node); + mirror_map[node.get()] = new_node; + reverse_mirror_map[new_node] = node; }); std::vector outputs; @@ -459,103 +468,87 @@ Graph QuantizeGraph(Graph &&src) { Graph ret; ret.outputs = std::move(outputs); - return ret; -} -Graph SetCalibTableToQuantizedGraph(Graph&& g) { - static const auto& flist_outputs = - nnvm::Op::GetAttr("FListOutputNames"); - static const auto& need_requantize_map = - nnvm::Op::GetAttr("FNeedRequantize"); - const auto& calib_table = - g.GetAttr>>("calib_table"); - DFSVisit(g.outputs, [&](const NodePtr& node) { - // If the current op is requantize - // find the thresholds from the calibration table with the key equal - // to the current op's input node name, e.g. a quantized_conv2d node. - if (node->op() == Op::Get("_contrib_requantize")) { - NodePtr quantized_op_node = node->inputs[0].node; - CHECK(quantized_op_node->op() != nullptr) << quantized_op_node->attrs.name - << " must be an quantized op node"; - CHECK(need_requantize_map.count(quantized_op_node->op()) > 0 - && need_requantize_map[quantized_op_node->op()](quantized_op_node->attrs)) - << quantized_op_node->attrs.name << " op must register FNeedRequantize attr" - " and the attr func should return true"; - const std::string prefix = "quantized_"; - CHECK(std::equal(prefix.begin(), prefix.end(), quantized_op_node->attrs.name.begin())) - << "an quantized op should start with `quantized_`"; - - std::string out_data_name = quantized_op_node->attrs.name.substr(prefix.size()) + "_"; - auto list_output_names_func = flist_outputs.get(quantized_op_node->op(), nullptr); - // Here it's assumed that the quantized_op node only produces three outputs: - // out_data, min_range, and max_range. So we want to get the pre-calculated min_calib_range - // and max_calib_range from the calibration table for out_data. Here we create the output - // data name same as its constructed in GraphExecutor::ExecuteMonCallback. - if (list_output_names_func != nullptr) { - std::vector names = list_output_names_func(quantized_op_node->attrs); - CHECK_EQ(names.size(), 3U) << "ListOutputNames is expected to return three string for" - " quantized operators"; - out_data_name += names[0]; - } else { - out_data_name += "0"; - } - const auto calib_table_iter = calib_table.find(out_data_name); - if (calib_table_iter != calib_table.end()) { - node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); - node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); - node->op()->attr_parser(&(node->attrs)); - } - } else if (node->op() == Op::Get("_contrib_quantize_v2")) { - NodePtr float_op_node = node->inputs[0].node; - auto float_op_idx = node->inputs[0].index; - std::string out_data_name = float_op_node->attrs.name; - if (float_op_node->op()) { - auto list_output_names_func = flist_outputs.get(float_op_node->op(), nullptr); - // We want to get the pre-calculated min_range and max_range from the calibration table for - // out_data. Here we create the output data name same as its constructed in - // GraphExecutor::ExecuteMonCallback. - if (list_output_names_func != nullptr) { - std::vector names = list_output_names_func(float_op_node->attrs); - out_data_name += "_" + names[float_op_idx]; + static const auto& need_calib_input_map = Op::GetAttr("FNeedCalibrateInput"); + static const auto& need_calib_output_map = Op::GetAttr("FNeedCalibrateOutput"); + std::vector calib_nodes; + DFSVisit(ret.outputs, [&](const NodePtr& node) { + if (need_calib_input_map.count(node->op())) { + const auto calib_idx = need_calib_input_map[node->op()](node->attrs); + for (const auto &idx : calib_idx) { + if (reverse_mirror_map.count(node)) { + calib_nodes.push_back(GetOutputName( + {reverse_mirror_map[node], node->inputs[idx].index, node->inputs[idx].version})); } else { - out_data_name += "_" + std::to_string(float_op_idx); - } - } - const auto calib_table_iter = calib_table.find(out_data_name); - if (calib_table_iter != calib_table.end()) { - node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); - node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); - node->op()->attr_parser(&(node->attrs)); - const QuantizeV2Param& param = nnvm::get(node->attrs.parsed); - if (param.out_type == QuantizeOutType::kUint8 && - param.min_calib_range.value() < 0.0f) { - LOG(WARNING) << "Calibration statistics indicates that node `" << node->attrs.name - << "` has negative input, consider use `auto` or `int8` as out_type"; + const auto& e = node->inputs[idx]; + if (e.node->is_variable()) { + calib_nodes.push_back(e.node->attrs.name); + } else { + if (reverse_mirror_map.count(e.node)) { + const auto& fp32_in_node = reverse_mirror_map.at(e.node); + calib_nodes.push_back(GetOutputName({fp32_in_node, e.index, e.version})); + } else { + LOG(FATAL) << "Can't find calibration node for " << node->attrs.name; + } + } } } - } else if (node->op() == Op::Get("_contrib_quantized_batch_norm")) { - auto quantized_op_idx = node->inputs[0].index; - const std::string prefix = "quantized_"; - std::string out_data_name = node->attrs.name.substr(prefix.size()); - if (node->op()) { - auto list_output_names_func = flist_outputs.get(node->op(), nullptr); - // We want to get the pre-calculated min_range and max_range from the calibration table for - // out_data. Here we create the output data name same as its constructed in - // GraphExecutor::ExecuteMonCallback. - if (list_output_names_func != nullptr) { - std::vector names = list_output_names_func(node->attrs); - out_data_name += "_" + names[quantized_op_idx]; + } else if (need_calib_output_map.count(node->op())) { + const auto calib_idx = need_calib_output_map[node->op()](node->attrs); + for (const auto& idx : calib_idx) { + if (reverse_mirror_map.count(node)) { + calib_nodes.push_back(GetOutputName({reverse_mirror_map[node], static_cast(idx), 0})); } else { - out_data_name += "_" + std::to_string(quantized_op_idx); + calib_nodes.push_back(GetOutputName({node, static_cast(idx), 0})); } } + } + }); + ret.attrs["calib_nodes"] = std::make_shared(std::move(calib_nodes)); + return ret; +} - const auto calib_table_iter = calib_table.find(out_data_name); - if (calib_table_iter != calib_table.end()) { - node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); - node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); - node->op()->attr_parser(&(node->attrs)); - } +static inline void SetCalibTableForEntry( + const NodeEntry& e, const NodePtr& node, + const std::unordered_map>& calib_table) { + std::string out_data_name = GetOutputName(e); + ; + const std::string prefix = "quantized_"; + if (e.node->attrs.name.rfind(prefix, 0) == 0) { + out_data_name = out_data_name.substr(prefix.size()); + } + const auto calib_table_iter = calib_table.find(out_data_name); + if (calib_table_iter != calib_table.end()) { + std::cout << "Set calibration result to " << node->attrs.name + << " : min=" << calib_table_iter->second.first + << " max=" << calib_table_iter->second.second << std::endl; + node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); + node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); + if (node->op() && node->op()->attr_parser) node->op()->attr_parser(&(node->attrs)); + } else { + std::cout << "Can't find calibration result for " << node->attrs.name << std::endl; + } +} + +Graph SetCalibTableToQuantizedGraph(Graph&& g) { + const auto& calib_table = + g.GetAttr>>("calib_table"); + static const auto& need_calib_input_map = + Op::GetAttr("FNeedCalibrateInput"); + static const auto& need_calib_output_map = + Op::GetAttr("FNeedCalibrateOutput"); + std::cout << "Set calibration result to quantized symbol." << std::endl; + DFSVisit(g.outputs, [&](const NodePtr& node) { + if (need_calib_input_map.count(node->op())) { + const auto calib_idx = need_calib_input_map[node->op()](node->attrs); + CHECK_EQ(calib_idx.size(), 1); + const auto& idx = calib_idx[0]; + SetCalibTableForEntry(node->inputs[idx], node, calib_table); + } else if (need_calib_output_map.count(node->op())) { + const auto calib_idx = need_calib_output_map[node->op()](node->attrs); + CHECK_EQ(calib_idx.size(), 1); + const auto& idx = calib_idx[0]; + SetCalibTableForEntry({node, static_cast(idx), 0}, node, calib_table); } }); return g; @@ -564,6 +557,7 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) { NNVM_REGISTER_PASS(QuantizeGraph) .describe("") .set_body(QuantizeGraph) +.provide_graph_attr("calib_nodes") .set_change_graph(true); NNVM_REGISTER_PASS(SetCalibTableToQuantizedGraph) diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index e9017a58a82c..9a30386723be 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -114,6 +114,9 @@ If min_calib_range isn't presented, the output type will be int8. .set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ return std::vector{true}; }) +.set_attr("FNeedCalibrateInput", [](const NodeAttrs& attrs){ + return std::vector{0}; +}) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { diff --git a/src/operator/quantization/quantized_batch_norm.cc b/src/operator/quantization/quantized_batch_norm.cc index 3187826fe996..d52164871591 100644 --- a/src/operator/quantization/quantized_batch_norm.cc +++ b/src/operator/quantization/quantized_batch_norm.cc @@ -106,6 +106,9 @@ the float32 data into int8. .set_attr("FInferType", QuantizedBatchNormType) .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return false; }) +.set_attr("FNeedCalibrateOutput", [](const NodeAttrs& attrs){ + return std::vector{0}; +}) .add_argument("data", "NDArray-or-Symbol", "Input data.") .add_argument("gamma", "NDArray-or-Symbol", "gamma.") .add_argument("beta", "NDArray-or-Symbol", "beta.") diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc index 43682383b0d6..9ee299cf4ae9 100644 --- a/src/operator/quantization/requantize.cc +++ b/src/operator/quantization/requantize.cc @@ -77,6 +77,9 @@ inference accuracy. #else .set_attr("FCompute", RequantizeForward) #endif +.set_attr("FNeedCalibrateInput", [](const NodeAttrs& attrs){ + return std::vector{0}; +}) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { const RequantizeParam& param = nnvm::get(attrs.parsed); diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 97042141659f..084dd5645ab6 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -165,7 +165,6 @@ def check_quantize(sym, data_shape, out_type, name='conv', quantized_dtype=out_type, calib_mode='naive', calib_data=calib_data, - calib_layer=None, label_names=None, num_calib_examples=1, quantize_mode='full') @@ -217,7 +216,6 @@ def check_quantize_whole_model(out_type): calib_data = mx.nd.random.uniform(shape=data_shape) calib_data = mx.io.NDArrayIter(data=calib_data) calib_data = DummyIter(calib_data) - calib_layer = lambda name: name.endswith('_output') qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, arg_params=arg_params, aux_params=aux_params, @@ -226,7 +224,6 @@ def check_quantize_whole_model(out_type): quantized_dtype=out_type, calib_mode='naive', calib_data=calib_data, - calib_layer=calib_layer, label_names=None, num_calib_examples=1, quantize_mode='full') diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 061c5f762507..db448ea6e08a 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -690,8 +690,12 @@ def test_quantize_params(): params = {} for name in offline_params: params[name] = mx.nd.uniform(shape=(2, 2)) +<<<<<<< HEAD qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), offline_params=offline_params, quantize_mode='full') +======= + qsym, _ = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params, quantize_mode='full') +>>>>>>> Add calibrate op qparams = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) param_names = params.keys() qparam_names = qparams.keys() From 5aa8a26ee09c0fb672266353819dae2e6de6881e Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 21 Aug 2019 19:56:03 +0800 Subject: [PATCH 03/14] Fix merge Change-Id: I068df3d4f3309bc9b950a3b869da9407282c8577 --- python/mxnet/contrib/quantization.py | 6 +++--- src/c_api/c_api_symbolic.cc | 2 +- .../quantization/quantize_graph_pass.cc | 18 ++++++++++-------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index e59069f642b1..d2f8681325bf 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -84,7 +84,7 @@ def _quantize_params(qsym, params, th_dict): quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params -def _quantize_symbol(sym, excluded_symbols=None, excluded_operators=None, +def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, offline_params=None, quantized_dtype='int8', quantize_mode='smart'): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -524,7 +524,7 @@ def quantize_model(sym, arg_params, aux_params, if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8`, `uint8` or `auto`' % quantized_dtype) - qsym, calib_layer = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, + qsym, calib_layer = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, excluded_operators=excluded_op_names, offline_params=list( arg_params.keys()), @@ -666,7 +666,7 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8`, `uint8` or `auto`' % quantized_dtype) - qsym, calib_layer = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, + qsym, calib_layer = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, excluded_operators=excluded_op_names, offline_params=list( arg_params.keys()), diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index fecab85fe6d6..0ed50c347df0 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -919,7 +919,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); g = ApplyPass(std::move(g), "QuantizeGraph"); const auto& calib_nodes =g.GetAttr>("calib_nodes"); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str = std::move(calib_nodes); *out_num_calib_names = ret->ret_vec_str.size(); ret->ret_vec_charp.clear(); diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index cf2d4ba16146..77d5149429d3 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -126,11 +126,13 @@ bool isRegistered(NodePtr node, const int& dev_type) { fcomputestateful != nullptr || fcomputestateful_ex != nullptr); } -inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes, - const std::unordered_set& excluded_ops, - const int& dev_type, - std::unordered_map* quantized_node_map) { +inline QuantizeType NeedQuantize(NodePtr node, + const std::unordered_set& excluded_nodes, + const std::unordered_set& excluded_ops, + const int& dev_type, + std::unordered_map* quantized_node_map) { std::unordered_map quantized_node; + static auto& quantizable_map = Op::GetAttr("FQuantizable"); static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); const auto& op = node->op(); @@ -190,14 +192,17 @@ enum quantize_bit { static void MarkQuantizedNodes(const Graph& src, std::unordered_map& quantized_node_map) { const auto excluded_nodes = src.GetAttr>("excluded_nodes"); + const auto excluded_ops = src.GetAttr>("excluded_ops"); const auto quantize_mode = src.GetAttr("quantize_mode"); + const auto dev_type = src.GetAttr("target_ctx"); std::unordered_map> node_output_map; std::unordered_set must_quantize_nodes; std::unordered_map support_quantize_nodes; // Build node_output_map, must_quantize_nodes and support_quantize_nodes; DFSVisit(src.outputs, [&](const NodePtr& node) { - auto quantize_type = NeedQuantize(node, excluded_nodes, &quantized_node_map); + auto quantize_type = + NeedQuantize(node, excluded_nodes, excluded_ops, dev_type, &quantized_node_map); if (quantize_type == QuantizeType::kMust) { must_quantize_nodes.insert(node); } else if (quantize_type == QuantizeType::kSupport) { @@ -264,10 +269,7 @@ Graph QuantizeGraph(Graph &&src) { static const auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); - const auto dev_type = src.GetAttr("target_ctx"); const auto offline_params = src.GetAttr>("offline_params"); - const auto excluded_nodes = src.GetAttr>("excluded_nodes"); - const auto excluded_ops = src.GetAttr>("excluded_ops"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); std::unordered_map quantized_node_map; From b02c1a7c149241c07d1010eab2a7c13968bf0be8 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 22 Aug 2019 07:54:29 +0800 Subject: [PATCH 04/14] Fix lint Change-Id: Ia38369d31c33d0f76a671275910729dfce693950 --- include/mxnet/c_api_error.h | 12 ++++++++ python/mxnet/contrib/quantization.py | 19 +++++++----- src/c_api/c_api_symbolic.cc | 2 +- src/operator/nn/mkldnn/mkldnn_flatten-inl.h | 7 +++-- src/operator/quantization/calibrate.cc | 8 +++-- .../quantization/quantize_graph_pass.cc | 29 ++++++++++--------- 6 files changed, 50 insertions(+), 27 deletions(-) diff --git a/include/mxnet/c_api_error.h b/include/mxnet/c_api_error.h index b10bcc1d0983..e76a2c99f8d3 100644 --- a/include/mxnet/c_api_error.h +++ b/include/mxnet/c_api_error.h @@ -32,11 +32,23 @@ * The finally clause contains procedure to cleanup states when an error happens. */ #define MX_API_BEGIN() \ + try { \ on_enter_api(__FUNCTION__); #define MX_API_END() \ + } \ + catch (const std::exception &_except_) { \ + on_exit_api(); \ + return MXAPIHandleException(_except_); \ + } \ on_exit_api(); \ return 0; // NOLINT(*) #define MX_API_END_HANDLE_ERROR(Finalize) \ + } \ + catch (const std::exception &_except_) { \ + Finalize; \ + on_exit_api(); \ + return MXAPIHandleException(_except_); \ + } \ on_exit_api(); \ return 0; // NOLINT(*) /*! diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index d2f8681325bf..e3414e03f18c 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -23,7 +23,6 @@ except ImportError: stats = None -import sys import ctypes import logging import os @@ -154,6 +153,8 @@ def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, return Symbol(out), calib_layer def combine_histogram(old_hist, arr, new_min, new_max, new_th): + """ Collect layer histogram for arr and combine it with old histogram. + """ (old_hist, old_hist_edges, old_min, old_max, old_th) = old_hist if new_th <= old_th: hist, _ = np.histogram(arr, bins=len(old_hist), range=(-old_th, old_th)) @@ -282,7 +283,7 @@ def _collect_layer_output_min_max(mod, data, quantized_dtype, include_layer=None return collector.min_max_dict, num_examples -def _collect_layer_histogram(mod, data, quantized_dtype, include_layer=None, +def _collect_layer_histogram(mod, data, include_layer=None, max_num_examples=None, logger=None): """Collect layer outputs and save them in a dictionary mapped by layer names.""" collector = _LayerHistogramCollector(include_layer=include_layer, logger=logger) @@ -492,11 +493,13 @@ def quantize_model(sym, arg_params, aux_params, The maximum number of examples that user would like to use for calibration. If not provided, the whole calibration dataset will be used. quantized_dtype : str - The quantized destination type for input data. Currently support 'int8' - , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + The quantized destination type for input data. Currently support 'int8', 'uint8' and 'auto'. + 'auto' means automatically select output type according to calibration result. Default value is 'int8'. quantize_mode : str - The mode that quantization pass to apply. Support 'full' and 'smart'. 'full' means quantize all operator if possible. 'smart' means quantization pass will smartly choice which operator should be quantized. + The mode that quantization pass to apply. Support 'full' and 'smart'. + 'full' means quantize all operator if possible. + 'smart' means quantization pass will smartly choice which operator should be quantized. logger : Object A logging object for printing information during the process of quantization. @@ -548,7 +551,7 @@ def quantize_model(sym, arg_params, aux_params, mod.bind(for_training=False, data_shapes=calib_data.provide_data) mod.set_params(arg_params, aux_params) if calib_mode == 'entropy': - hist_dict, num_examples = _collect_layer_histogram(mod, calib_data, quantized_dtype, + hist_dict, num_examples = _collect_layer_histogram(mod, calib_data, include_layer=calib_layer, max_num_examples=num_calib_examples, logger=logger) @@ -681,8 +684,8 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), logger.info( 'Create a layer output collector for entropy calibration.') elif calib_mode == 'naive': - collector = _LayerOutputMinMaxCollector( - include_layer=calib_layer, logger=logger) + collector = _LayerOutputMinMaxCollector(quantized_dtype=quantized_dtype, + include_layer=calib_layer, logger=logger) logger.info( 'Create a layer output minmax collector for naive calibration') else: diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 0ed50c347df0..40391082e1b5 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -918,7 +918,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, g.attrs["target_ctx"] = std::make_shared(target_dev); g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); g = ApplyPass(std::move(g), "QuantizeGraph"); - const auto& calib_nodes =g.GetAttr>("calib_nodes"); + const auto& calib_nodes = g.GetAttr>("calib_nodes"); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str = std::move(calib_nodes); *out_num_calib_names = ret->ret_vec_str.size(); diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h index 376db076c28e..ae890d8f3d91 100644 --- a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h @@ -18,11 +18,13 @@ */ /*! - * \file mkldnn_flatten.cc + * \file mkldnn_flatten-inl.h * \brief Implement flatten operator by using mkldnn reorder primitive * \author Wuxun Zhang */ +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_ #if MXNET_USE_MKLDNN == 1 #include "mkldnn_reshape-inl.h" @@ -42,4 +44,5 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, co } // namespace op } // namespace mxnet -#endif +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_ diff --git a/src/operator/quantization/calibrate.cc b/src/operator/quantization/calibrate.cc index 4255bf907f4c..26789c537b0f 100644 --- a/src/operator/quantization/calibrate.cc +++ b/src/operator/quantization/calibrate.cc @@ -23,6 +23,7 @@ * \brief */ +#include #include "./calibrate-inl.h" namespace mxnet { @@ -61,7 +62,10 @@ std::vector SmoothDistribution(const std::vector& p, const float e } return ret; } -static float ComputeEntropy(std::vector& p, std::vector& q) { + +static float ComputeEntropy(std::vector* p_ptr, std::vector* q_ptr) { + std::vector& p = *p_ptr; + std::vector& q = *q_ptr; CHECK_EQ(p.size(), q.size()); float p_sum = std::accumulate(p.begin(), p.end(), 0.f); float q_sum = std::accumulate(q.begin(), q.end(), 0.f); @@ -150,7 +154,7 @@ void CalibrateComputeCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, if (!q.size()) { divergence[i - num_half_quantized_bins] = std::numeric_limits::infinity(); } else { - divergence[i - num_half_quantized_bins] = ComputeEntropy(p, q); + divergence[i - num_half_quantized_bins] = ComputeEntropy(&p, &q); } } diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 77d5149429d3..f9a9ca48843b 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -55,9 +55,9 @@ static inline size_t GetNumOutputs(NodePtr node) { } static inline std::string GetOutputName(const NodeEntry& e) { - nnvm::Symbol sym; - sym.outputs.push_back(e); - return sym.ListOutputNames()[0]; + nnvm::Symbol sym; + sym.outputs.push_back(e); + return sym.ListOutputNames()[0]; } NodePtr CreateNode(std::string op_name, std::string node_name) { @@ -174,8 +174,7 @@ inline QuantizeType NeedQuantize(NodePtr node, } if (quantizable_map.count(op)) { return quantizable_map[op](node->attrs); - } - else { + } else { return QuantizeType::kSupport; } } @@ -190,7 +189,7 @@ enum quantize_bit { }; static void MarkQuantizedNodes(const Graph& src, - std::unordered_map& quantized_node_map) { + std::unordered_map* quantized_node_map) { const auto excluded_nodes = src.GetAttr>("excluded_nodes"); const auto excluded_ops = src.GetAttr>("excluded_ops"); const auto quantize_mode = src.GetAttr("quantize_mode"); @@ -202,7 +201,7 @@ static void MarkQuantizedNodes(const Graph& src, // Build node_output_map, must_quantize_nodes and support_quantize_nodes; DFSVisit(src.outputs, [&](const NodePtr& node) { auto quantize_type = - NeedQuantize(node, excluded_nodes, excluded_ops, dev_type, &quantized_node_map); + NeedQuantize(node, excluded_nodes, excluded_ops, dev_type, quantized_node_map); if (quantize_type == QuantizeType::kMust) { must_quantize_nodes.insert(node); } else if (quantize_type == QuantizeType::kSupport) { @@ -254,9 +253,9 @@ static void MarkQuantizedNodes(const Graph& src, // Summarize the result for (const auto& node : support_quantize_nodes) { - CHECK(quantized_node_map.count(node.first)); + CHECK(quantized_node_map->count(node.first)); if (node.second != (kFromInput | kFromOutput)) { - quantized_node_map.erase(node.first); + quantized_node_map->erase(node.first); } } } else { @@ -273,7 +272,7 @@ Graph QuantizeGraph(Graph &&src) { const auto quantized_dtype = src.GetAttr("quantized_dtype"); std::unordered_map quantized_node_map; - MarkQuantizedNodes(src, quantized_node_map); + MarkQuantizedNodes(src, &quantized_node_map); // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key @@ -471,8 +470,10 @@ Graph QuantizeGraph(Graph &&src) { Graph ret; ret.outputs = std::move(outputs); - static const auto& need_calib_input_map = Op::GetAttr("FNeedCalibrateInput"); - static const auto& need_calib_output_map = Op::GetAttr("FNeedCalibrateOutput"); + static const auto& need_calib_input_map = + Op::GetAttr("FNeedCalibrateInput"); + static const auto& need_calib_output_map = + Op::GetAttr("FNeedCalibrateOutput"); std::vector calib_nodes; DFSVisit(ret.outputs, [&](const NodePtr& node) { if (need_calib_input_map.count(node->op())) { @@ -499,7 +500,8 @@ Graph QuantizeGraph(Graph &&src) { const auto calib_idx = need_calib_output_map[node->op()](node->attrs); for (const auto& idx : calib_idx) { if (reverse_mirror_map.count(node)) { - calib_nodes.push_back(GetOutputName({reverse_mirror_map[node], static_cast(idx), 0})); + calib_nodes.push_back( + GetOutputName({reverse_mirror_map[node], static_cast(idx), 0})); } else { calib_nodes.push_back(GetOutputName({node, static_cast(idx), 0})); } @@ -514,7 +516,6 @@ static inline void SetCalibTableForEntry( const NodeEntry& e, const NodePtr& node, const std::unordered_map>& calib_table) { std::string out_data_name = GetOutputName(e); - ; const std::string prefix = "quantized_"; if (e.node->attrs.name.rfind(prefix, 0) == 0) { out_data_name = out_data_name.substr(prefix.size()); From 6e84101f1aa25f31e0656a282f5693f0f404978d Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 23 Aug 2019 09:20:05 +0800 Subject: [PATCH 05/14] Run CI From dcdde8113e66137859760bf0e3b115177190e03e Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 23 Aug 2019 12:32:01 +0800 Subject: [PATCH 06/14] Run CI From 6e309eb99885eb8b503abd94d64c4c23fb7971a4 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 23 Aug 2019 14:39:21 +0800 Subject: [PATCH 07/14] Update test_quantization.py --- tests/python/quantization/test_quantization.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index db448ea6e08a..cce24e18a8bc 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -690,12 +690,8 @@ def test_quantize_params(): params = {} for name in offline_params: params[name] = mx.nd.uniform(shape=(2, 2)) -<<<<<<< HEAD - qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), - offline_params=offline_params, quantize_mode='full') -======= - qsym, _ = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params, quantize_mode='full') ->>>>>>> Add calibrate op + qsym, _ = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + offline_params=offline_params, quantize_mode='full') qparams = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) param_names = params.keys() qparam_names = qparams.keys() From ced11b87b8aeae3c616cbe8602156bad54bfa173 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 27 Aug 2019 09:15:04 +0800 Subject: [PATCH 08/14] Run CI From 7c697f6c9f60c005dae97b47cdceb87492acb11b Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 27 Aug 2019 10:38:07 +0800 Subject: [PATCH 09/14] Fix CI Change-Id: I7479327db5ebc7c57b7bd810a67d2b765c820534 --- src/operator/quantization/calibrate.cc | 4 +-- .../mkldnn/mkldnn_quantize_v2-inl.h | 6 ++-- .../python/quantization/test_quantization.py | 35 ++++++++++++++----- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/operator/quantization/calibrate.cc b/src/operator/quantization/calibrate.cc index 26789c537b0f..408ba4b5581f 100644 --- a/src/operator/quantization/calibrate.cc +++ b/src/operator/quantization/calibrate.cc @@ -103,7 +103,7 @@ void CalibrateComputeCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, std::vector thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f); std::vector divergence(thresholds.size(), 0.f); #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (index_t i = num_quantized_bins / 2; i < static_cast(num_bins / 2 + 1); i++) { + for (index_t i = num_quantized_bins / 2; i <= zero_bin_idx; i++) { const size_t p_bin_idx_start = zero_bin_idx - i; const size_t p_bin_idx_stop = zero_bin_idx + i + 1; thresholds[i - num_half_quantized_bins] = hist_edges_ptr[p_bin_idx_stop]; @@ -123,7 +123,7 @@ void CalibrateComputeCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, } } // calculate how many bins should be merged to generate quantized distribution q - const float num_merged_bins = sliced_nd_hist.size() / num_quantized_bins; + const float num_merged_bins = static_cast(sliced_nd_hist.size()) / num_quantized_bins; // merge hist into num_quantized_bins bins std::vector quantized_bins(num_quantized_bins, 0); for (index_t j = 0; j < num_quantized_bins; j++) { diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index bd1b47e4c2de..7cdce8e32bc8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -68,10 +68,10 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector() = 0; - *outputs[2].data().dptr() = 255; + *outputs[2].data().dptr() = kUint8Range; } else { - *outputs[1].data().dptr() = -127; - *outputs[2].data().dptr() = 127; + *outputs[1].data().dptr() = -kInt8Range; + *outputs[2].data().dptr() = kInt8Range; } } if (req[0] != kWriteInplace) { diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index f5efbd1a3997..38f5caf656a8 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -80,7 +80,7 @@ def baseline_dequantization(qdata, real_range, qdata_np): def test_nd_array_dequantization(qdata, min_range, max_range, expected_result): data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') assert data.dtype == np.float32 - assert_almost_equal(data.asnumpy(), expected_result) + assert_almost_equal(data.asnumpy(), expected_result, atol = 1) def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result): sym_data = mx.sym.Variable('data') @@ -92,7 +92,7 @@ def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_resul args={'data':qdata, 'min_range':min_range, 'max_range':max_range}) data = out.forward()[0] assert data.dtype == np.float32 - assert_almost_equal(data.asnumpy(), expected_result) + assert_almost_equal(data.asnumpy(), expected_result, atol = 1) real_range = 402.3347 shape = rand_shape_nd(4) @@ -691,7 +691,7 @@ def test_quantize_params(): params = {} for name in offline_params: params[name] = mx.nd.uniform(shape=(2, 2)) - qsym, _ = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + qsym, _ = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), offline_params=offline_params, quantize_mode='full') qparams = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) param_names = params.keys() @@ -1086,11 +1086,22 @@ def test_optimal_threshold_adversarial_case(): # The worst case for the optimal_threshold function is when the values are concentrated # at one edge: [0, 0, ..., 1000]. (histogram) # We want to make sure that the optimal threshold in this case is the max. - arr = np.array([2] * 1000) + hist = [] + hist_edges = [] + min_val = -2 + max_val = 2 + for i in range(0, 998): + hist.append(0) + for i in range(0, 999): + hist_edges.append((max_val - min_val) / 999 * i + min_val) + hist.append(1000) + hist_edges.append(max_val) + hist_data = (hist, hist_edges, min_val, max_val, max_val) for dtype in ['uint8', 'int8', 'auto']: - res = mx.contrib.quant._get_optimal_threshold(arr, dtype, num_quantized_bins=5) + res = mx.contrib.quant._get_optimal_threshold(hist_data, dtype, num_quantized_bins=5) # The threshold should be 2. - assert res[3] - 2 < 1e-5 + print (res) + assert abs(res[2] - 2) < 1e-5 @with_seed() @@ -1103,9 +1114,15 @@ def get_threshold(nd): return mx.nd.maximum(mx.nd.abs(min_nd), mx.nd.abs(max_nd)).asnumpy() for dtype in ['uint8', 'int8', 'auto']: - nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64)} - expected_threshold = get_threshold(nd_dict['layer1']) - th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict, dtype) + nd = mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64) + expected_threshold = get_threshold(nd) + arr = nd.asnumpy() + min_range = np.min(arr) + max_range = np.max(arr) + th = max(abs(min_range), abs(max_range)) + hist, hist_edges = np.histogram(arr, bins=8001, range=(-th, th)) + hist_dict = {'layer1' : (hist, hist_edges, min_range, max_range, th)} + th_dict = mx.contrib.quant._get_optimal_thresholds(hist_dict, dtype) assert 'layer1' in th_dict assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) From 8f2490aebfbb2f1518a75a9f78235bfec81ae597 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 27 Aug 2019 12:55:36 +0800 Subject: [PATCH 10/14] Fix CI Change-Id: I4273938cb972c12b8f43dbd95c736a7d32df040e --- python/mxnet/contrib/quantization.py | 1 - src/common/utils.h | 10 ++++++++++ src/executor/graph_executor.cc | 14 ++------------ src/operator/quantization/calibrate.cc | 12 ++++++------ .../quantization/quantize_graph_pass.cc | 17 ++++++----------- tests/python/quantization/test_quantization.py | 2 +- 6 files changed, 25 insertions(+), 31 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index e3414e03f18c..ebb3ae4303a0 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -532,7 +532,6 @@ def quantize_model(sym, arg_params, aux_params, offline_params=list( arg_params.keys()), quantized_dtype=quantized_dtype, quantize_mode=quantize_mode) - th_dict = {} if calib_mode is not None and calib_mode != 'none': if not isinstance(ctx, Context): diff --git a/src/common/utils.h b/src/common/utils.h index 9dad5df84fd2..6edfa3c6bb56 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -804,6 +805,15 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) { } } +/*! + * \brief This is function can return the output names of a NodeEntry. + */ +static inline std::string GetOutputName(const nnvm::NodeEntry& e) { + nnvm::Symbol sym; + sym.outputs.push_back(e); + return sym.ListOutputNames()[0]; +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 7bdeac708003..87cb856debb5 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1370,24 +1370,14 @@ void GraphExecutor::ExecuteMonInputCallback(size_t nid) { } void GraphExecutor::ExecuteMonOutputCallback(size_t nid) { - static const auto& flist_outputs = - nnvm::Op::GetAttr("FListOutputNames"); const auto& idx = graph_.indexed_graph(); - std::vector output_names; OpNode& opnode = op_nodes_[nid]; const auto& inode = idx[nid]; const auto& node = idx[nid].source; - if (flist_outputs.count(node->op())) { - output_names = flist_outputs[node->op()](node->attrs); - } else { - for (size_t i = 0; i < node->num_outputs(); ++i) { - output_names.emplace_back(std::to_string(i)); - } - } - CHECK_EQ(opnode.exec->out_array.size(), output_names.size()); for (size_t i = 0; i < opnode.exec->out_array.size(); ++i) { NDArray *cpy = new NDArray(opnode.exec->out_array[i]); - std::string name = inode.source->attrs.name + "_" + output_names[i]; + nnvm::NodePtr node_ptr = std::make_shared(*node); + std::string name = GetOutputName({node_ptr, i, 0}); this->monitor_callback_(name.c_str(), reinterpret_cast(cpy)); } } diff --git a/src/operator/quantization/calibrate.cc b/src/operator/quantization/calibrate.cc index 408ba4b5581f..3852eab979f8 100644 --- a/src/operator/quantization/calibrate.cc +++ b/src/operator/quantization/calibrate.cc @@ -123,23 +123,23 @@ void CalibrateComputeCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, } } // calculate how many bins should be merged to generate quantized distribution q - const float num_merged_bins = static_cast(sliced_nd_hist.size()) / num_quantized_bins; + const auto num_merged_bins = sliced_nd_hist.size() / num_quantized_bins; // merge hist into num_quantized_bins bins std::vector quantized_bins(num_quantized_bins, 0); for (index_t j = 0; j < num_quantized_bins; j++) { - const int start = std::round(j * num_merged_bins); - const int stop = std::round((j + 1) * num_merged_bins); + const int start = j * num_merged_bins; + const int stop = (j + 1) * num_merged_bins; quantized_bins[j] = std::accumulate(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, 0); } quantized_bins.back() += std::accumulate( - sliced_nd_hist.begin() + static_cast(std::round(num_quantized_bins * num_merged_bins)), + sliced_nd_hist.begin() + static_cast(num_quantized_bins * num_merged_bins), sliced_nd_hist.end(), 0); // expand quantized_bins into p.size bins std::vector q(sliced_nd_hist.size(), 0); for (index_t j = 0; j < num_quantized_bins; j++) { - const int start = std::round(j * num_merged_bins); - const int stop = std::round((j + 1) * num_merged_bins); + const int start = j * num_merged_bins; + const int stop = (j == num_quantized_bins - 1) ? q.size() : ((j + 1) * num_merged_bins); int norm = std::count_if(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, [](size_t i) { return i != 0; }); if (norm) { diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index f9a9ca48843b..182f6339308a 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -31,6 +31,7 @@ #include #include #include "quantize_v2-inl.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -54,12 +55,6 @@ static inline size_t GetNumOutputs(NodePtr node) { return num_outputs; } -static inline std::string GetOutputName(const NodeEntry& e) { - nnvm::Symbol sym; - sym.outputs.push_back(e); - return sym.ListOutputNames()[0]; -} - NodePtr CreateNode(std::string op_name, std::string node_name) { NodePtr node = Node::Create(); node->attrs.name = node_name; @@ -480,7 +475,7 @@ Graph QuantizeGraph(Graph &&src) { const auto calib_idx = need_calib_input_map[node->op()](node->attrs); for (const auto &idx : calib_idx) { if (reverse_mirror_map.count(node)) { - calib_nodes.push_back(GetOutputName( + calib_nodes.push_back(common::GetOutputName( {reverse_mirror_map[node], node->inputs[idx].index, node->inputs[idx].version})); } else { const auto& e = node->inputs[idx]; @@ -489,7 +484,7 @@ Graph QuantizeGraph(Graph &&src) { } else { if (reverse_mirror_map.count(e.node)) { const auto& fp32_in_node = reverse_mirror_map.at(e.node); - calib_nodes.push_back(GetOutputName({fp32_in_node, e.index, e.version})); + calib_nodes.push_back(common::GetOutputName({fp32_in_node, e.index, e.version})); } else { LOG(FATAL) << "Can't find calibration node for " << node->attrs.name; } @@ -501,9 +496,9 @@ Graph QuantizeGraph(Graph &&src) { for (const auto& idx : calib_idx) { if (reverse_mirror_map.count(node)) { calib_nodes.push_back( - GetOutputName({reverse_mirror_map[node], static_cast(idx), 0})); + common::GetOutputName({reverse_mirror_map[node], static_cast(idx), 0})); } else { - calib_nodes.push_back(GetOutputName({node, static_cast(idx), 0})); + calib_nodes.push_back(common::GetOutputName({node, static_cast(idx), 0})); } } } @@ -515,7 +510,7 @@ Graph QuantizeGraph(Graph &&src) { static inline void SetCalibTableForEntry( const NodeEntry& e, const NodePtr& node, const std::unordered_map>& calib_table) { - std::string out_data_name = GetOutputName(e); + std::string out_data_name = common::GetOutputName(e); const std::string prefix = "quantized_"; if (e.node->attrs.name.rfind(prefix, 0) == 0) { out_data_name = out_data_name.substr(prefix.size()); diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 38f5caf656a8..8b74b25c7030 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -94,7 +94,7 @@ def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_resul assert data.dtype == np.float32 assert_almost_equal(data.asnumpy(), expected_result, atol = 1) - real_range = 402.3347 + real_range = 128 shape = rand_shape_nd(4) qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8) qdata, min_range, max_range = get_test_data(real_range, qdata_np) From c3a8b94afa741cb0ee991698ec777d93646359b0 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 27 Aug 2019 14:12:50 +0800 Subject: [PATCH 11/14] Fix CI Change-Id: I80b47bd1d95520a7cd78cacbbc1a85fe0900123d --- src/executor/graph_executor.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 87cb856debb5..d92253266f35 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1372,12 +1372,11 @@ void GraphExecutor::ExecuteMonInputCallback(size_t nid) { void GraphExecutor::ExecuteMonOutputCallback(size_t nid) { const auto& idx = graph_.indexed_graph(); OpNode& opnode = op_nodes_[nid]; - const auto& inode = idx[nid]; const auto& node = idx[nid].source; for (size_t i = 0; i < opnode.exec->out_array.size(); ++i) { NDArray *cpy = new NDArray(opnode.exec->out_array[i]); nnvm::NodePtr node_ptr = std::make_shared(*node); - std::string name = GetOutputName({node_ptr, i, 0}); + std::string name = GetOutputName({node_ptr, static_cast(i), 0}); this->monitor_callback_(name.c_str(), reinterpret_cast(cpy)); } } From 3a91337b4971eda2449cf6cadb154d9041ebe092 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 28 Aug 2019 09:03:05 +0800 Subject: [PATCH 12/14] Fix CI Change-Id: I56542470010e7bc403f62dc8a8991c2fb58d229e --- tests/python/unittest/test_operator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ceee51a3e503..f864003f7c60 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5571,9 +5571,10 @@ def test_quantization_op(): qa_real = mx.nd.array([[18, 75], [77, 109]]) a_real = mx.nd.array([[0.14173228, 0.5905512], [0.6062992, 0.8582677]]) - + print(a_.asnumpy()) + print(a_real.asnumpy()) assert same(qa.asnumpy(), qa_real.asnumpy()) - assert same(a_.asnumpy(), a_real.asnumpy()) + assert_almost_equal(a_.asnumpy(), a_real.asnumpy(), rtol=1e-2) @with_seed() def test_index_copy(): From fb4fbba26b0057c0bbda17e8c04899085fdd3cb8 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 28 Aug 2019 12:12:36 +0800 Subject: [PATCH 13/14] Fix CI Change-Id: If8482fe4da2f3d627dd3cbac8795e021a09a441f --- tests/python/quantization/test_quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 8b74b25c7030..d753b514302e 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -967,7 +967,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N ctx=mx.current_context(), quantized_dtype=qdtype, calib_mode='none', - quantize_model='full') + quantize_mode='full') check_params(arg_params, qarg_params, qsym) check_params(aux_params, qaux_params) check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape) @@ -985,7 +985,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N calib_mode='naive', calib_data=calib_data, num_calib_examples=20, - quantize_model='full') + quantize_mode='full') check_params(arg_params, qarg_params, qsym) check_params(aux_params, qaux_params) check_qsym_calibrated(qsym) @@ -1052,7 +1052,7 @@ def test_quantize_sym_with_calib(): sym = get_fp32_sym() offline_params = [name for name in sym.list_arguments() if not name.startswith('data') and not name.endswith('label')] - qsym = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + qsym, _ = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), offline_params=offline_params, quantize_mode='full') requantize_op_names = ['requantize_conv', 'requantize_fc'] th_dict = {'conv_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0)), From eca88d619cbfae57e4d8bcd4e7eb2c918a74b3f1 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 29 Aug 2019 08:42:29 +0800 Subject: [PATCH 14/14] Fix GPU Change-Id: Ic239cbf7aa3d111f2895badd1cac196fce6a1b86 --- python/mxnet/contrib/quantization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index ebb3ae4303a0..ef24bddf9bd1 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -325,8 +325,8 @@ def _get_optimal_threshold(hist_data, quantized_dtype, num_quantized_bins=255): if min_val >= 0 and quantized_dtype in ['auto', 'uint8']: # We need to move negative bins to positive bins to fit uint8 range. num_quantized_bins = num_quantized_bins * 2 + 1 - hist = ndarray.array(hist) - hist_edges = ndarray.array(hist_edges) + hist = ndarray.array(hist, ctx=cpu()) + hist_edges = ndarray.array(hist_edges, ctx=cpu()) threshold, divergence = ndarray.contrib.calibrate_entropy(hist=hist, hist_edges=hist_edges, num_quantized_bins=num_quantized_bins)