From 99b0eb15169b3e6bc974f1043a43a82f7b998a62 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 11 Feb 2019 13:38:24 +0800 Subject: [PATCH 1/6] Enable int8 data layer Change-Id: I3d97ef80b7466d7555f4970e24f02e8dfba6be2b --- docs/api/perl/io.md | 1 + docs/api/python/io/io.md | 1 + .../quantization/imagenet_gen_qsym_mkldnn.py | 8 +- example/quantization/imagenet_inference.py | 84 ++++++++++++++----- include/mxnet/c_api.h | 4 +- perl-package/AI-MXNet/lib/AI/MXNet/IO.pm | 1 + python/mxnet/contrib/quantization.py | 11 ++- src/c_api/c_api_symbolic.cc | 4 +- src/io/iter_image_recordio_2.cc | 55 +++++++++--- .../quantization/quantize_graph_pass.cc | 32 ++++--- tests/python/train/test_dtype.py | 29 +++++++ 11 files changed, 181 insertions(+), 49 deletions(-) diff --git a/docs/api/perl/io.md b/docs/api/perl/io.md index 3310f26aba18..f2da7d66cc64 100644 --- a/docs/api/perl/io.md +++ b/docs/api/perl/io.md @@ -52,6 +52,7 @@ Then we can call `$mod->fit($nd_iter, num_epoch=>2)` to train `loss` by 2 epochs mx->io->NDArrayIter mx->io->CSVIter mx->io->ImageRecordIter +mx->io->ImageRecordInt8Iter mx->io->ImageRecordUInt8Iter mx->io->MNISTIter mx->recordio->MXRecordIO diff --git a/docs/api/python/io/io.md b/docs/api/python/io/io.md index a069eae61e71..f9182a29e1d0 100644 --- a/docs/api/python/io/io.md +++ b/docs/api/python/io/io.md @@ -58,6 +58,7 @@ A detailed tutorial is available at io.CSVIter io.LibSVMIter io.ImageRecordIter + io.ImageRecordInt8Iter io.ImageRecordUInt8Iter io.MNISTIter recordio.MXRecordIO diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index d807e7f2d19d..72a620dc7f64 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -147,6 +147,8 @@ def save_params(fname, arg_params, aux_params, logger=None): help='If enabled, the quantize op will ' 'be calibrated offline if calibration mode is ' 'enabled') + parser.add_argument('--use-quantized-data-layer', type=bool, default=True, + help='If enabled, data layer will be already quantized.') args = parser.parse_args() ctx = mx.cpu(0) logging.basicConfig() @@ -273,6 +275,7 @@ def save_params(fname, arg_params, aux_params, logger=None): qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, + use_quantized_data_layer=args.use_quantized_data_layer, logger=logger) sym_name = '%s-symbol.json' % (prefix + '-quantized') else: @@ -295,7 +298,10 @@ def save_params(fname, arg_params, aux_params, logger=None): calib_mode=calib_mode, calib_data=data, num_calib_examples=num_calib_batches * batch_size, calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - label_names=(label_name,), logger=logger) + label_names=(label_name,), + use_quantized_data_layer=args.use_quantized_data_layer, + logger=logger) + if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches elif calib_mode == 'naive': diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index 0725165b0ca5..f53afc596798 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -19,6 +19,7 @@ import logging import os import time +import numpy as np import mxnet as mx from mxnet import nd from mxnet.contrib.quantization import * @@ -98,7 +99,7 @@ def score(sym, arg_params, aux_params, data, devs, 
label_name, max_num_examples, logger.info(m.get()) -def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): +def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger=None): # get mod cur_path = os.path.dirname(os.path.realpath(__file__)) symbol_file_path = os.path.join(cur_path, symbol_file) @@ -106,14 +107,28 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): logger.info('Loading symbol from file %s' % symbol_file_path) sym = mx.sym.load(symbol_file_path) mod = mx.mod.Module(symbol=sym, context=ctx) - mod.bind(for_training = False, - inputs_need_grad = False, - data_shapes = [('data', (batch_size,)+data_shape)]) + if data_layer_type == "float32": + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.float32) + elif data_layer_type == 'uint8': + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.uint8) + else: # int8 + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.int8) + mod.bind(for_training=False, + inputs_need_grad=False, + data_shapes=[dshape]) mod.init_params(initializer=mx.init.Xavier(magnitude=2.)) # get data - data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes] - batch = mx.io.DataBatch(data, []) # empty label + if data_layer_type == "float32": + data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx, dtype=data_layer_type) + for _, shape in mod.data_shapes] + else: + data = [mx.nd.full(shape=shape, val=127, ctx=ctx, dtype=data_layer_type) + for _, shape in mod.data_shapes] + batch = mx.io.DataBatch(data, []) # empty label # run dry_run = 5 # use 5 iterations to warm up @@ -152,6 +167,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): help='shuffling seed, see' ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' ' for more details') + parser.add_argument('--data-layer-type', type=str, default="float32", + choices=['float32', 'int8', 'uint8'], + help='data type for data layer') args = parser.parse_args() @@ -192,24 +210,52 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): data_shape = tuple([int(i) for i in image_shape.split(',')]) logger.info('Input data shape = %s' % str(data_shape)) + data_layer_type = args.data_layer_type if args.benchmark == False: dataset = args.dataset download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset) logger.info('Dataset for inference: %s' % dataset) # creating data iterator - data = mx.io.ImageRecordIter(path_imgrec=dataset, - label_width=1, - preprocess_threads=data_nthreads, - batch_size=batch_size, - data_shape=data_shape, - label_name=label_name, - rand_crop=False, - rand_mirror=False, - shuffle=True, - shuffle_chunk_seed=3982304, - seed=48564309, - **combine_mean_std) + if data_layer_type == 'float32': + data = mx.io.ImageRecordIter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=True, + shuffle_chunk_seed=3982304, + seed=48564309, + **combine_mean_std) + elif data_layer_type == 'uint8': + data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=True, + shuffle_chunk_seed=3982304, + 
seed=48564309, + **combine_mean_std) + else: #int8 + data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=True, + shuffle_chunk_seed=3982304, + seed=48564309, + **combine_mean_std) # loading model sym, arg_params, aux_params = load_model(symbol_file, param_file, logger) @@ -224,5 +270,5 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): max_num_examples=num_inference_images, logger=logger) else: logger.info('Running model %s for inference' % symbol_file) - speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, logger) + speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, data_layer_type, logger) logger.info('batch size %2d, image/sec: %f', batch_size, speed) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d6e13ebcf051..bca97b7ee3f5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1567,12 +1567,14 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could. + * \param use_quantized_data_layer if use quantized data layer. */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, const char **excluded_symbols, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype, const bool calib_quantize); + const char *quantized_dtype, const bool calib_quantize, + const bool use_quantized_data_layer); /*! * \brief Set calibration table to node attributes in the sym diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm index 297ceb8c0b24..19e7cfdb8fe3 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm @@ -642,6 +642,7 @@ extends 'AI::MXNet::DataIter'; mx->io->CSVIter Returns the CSV file iterator. mx->io->LibSVMIter Returns the LibSVM iterator which returns data with csr storage type. mx->io->ImageRecordIter Iterates on image RecordIO files + mx->io->ImageRecordInt8Iter Iterating on image RecordIO files mx->io->ImageRecordUInt8Iter Iterating on image RecordIO files mx->io->MNISTIter Iterating on the MNIST dataset. mx->recordio->MXRecordIO Reads/writes RecordIO data format, supporting sequential read and write. diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 96183bb7a172..227dde7e7edb 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -80,7 +80,8 @@ def _quantize_params(qsym, params, th_dict): quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params -def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8'): +def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8', + use_quantized_data_layer=False): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. 
@@ -120,7 +121,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_ mx_uint(num_offline), c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), - ctypes.c_bool(True))) + ctypes.c_bool(True), + ctypes.c_bool(use_quantized_data_layer))) return Symbol(out) @@ -419,7 +421,7 @@ def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, - quantized_dtype='int8', logger=logging): + quantized_dtype='int8', use_quantized_data_layer=False, logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -495,7 +497,8 @@ def quantize_model(sym, arg_params, aux_params, ' expected `int8`, `uint8` or `auto`' % quantized_dtype) qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype) + quantized_dtype=quantized_dtype, + use_quantized_data_layer = use_quantized_data_layer) th_dict = {} if calib_mode is not None and calib_mode != 'none': diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 32b63c11dd9a..34c3ef740c77 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -651,7 +651,8 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, - const bool calib_quantize) { + const bool calib_quantize, + const bool use_quantized_data_layer) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -668,6 +669,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); + g.attrs["use_quantized_data_layer"] = std::make_shared(use_quantized_data_layer); g = ApplyPass(std::move(g), "QuantizeGraph"); s->outputs = g.outputs; *ret_sym_handle = s; diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index 00c38198659f..6d49006d857c 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -372,6 +372,7 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, float RGBA_MULT[4] = { 0 }; float RGBA_BIAS[4] = { 0 }; float RGBA_MEAN[4] = { 0 }; + int16_t RGBA_MEAN_INT[4] = {0}; mshadow::Tensor& data = (*data_ptr); if (!std::is_same::value) { RGBA_MULT[0] = contrast_scaled / normalize_param_.std_r; @@ -387,6 +388,10 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, RGBA_MEAN[1] = normalize_param_.mean_g; RGBA_MEAN[2] = normalize_param_.mean_b; RGBA_MEAN[3] = normalize_param_.mean_a; + RGBA_MEAN_INT[0] = std::round(normalize_param_.mean_r); + RGBA_MEAN_INT[1] = std::round(normalize_param_.mean_g); + RGBA_MEAN_INT[2] = std::round(normalize_param_.mean_b); + RGBA_MEAN_INT[3] = std::round(normalize_param_.mean_a); } } @@ -408,17 +413,30 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, for (int i = 0; i < res.rows; ++i) { const uchar* im_data = res.ptr(i); for (int j = 0; j < res.cols; ++j) { - for (int k = 0; k < n_channels; ++k) { - RGBA[k] = im_data[swap_indices[k]]; - } - if (!std::is_same::value) { - // 
normalize/mirror here to avoid memory copies - // logic from iter_normalize.h, function SetOutImg + if (std::is_same::value) { + if (meanfile_ready_) { + for (int k = 0; k < n_channels; ++k) { + RGBA[k] = cv::saturate_cast(im_data[swap_indices[k]] - + static_cast(std::round(meanimg_[k][i][j]))); + } + } else { + for (int k = 0; k < n_channels; ++k) { + RGBA[k] = cv::saturate_cast(im_data[swap_indices[k]] - RGBA_MEAN_INT[k]); + } + } + } else { for (int k = 0; k < n_channels; ++k) { - if (meanfile_ready_) { - RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k]; - } else { - RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k]; + RGBA[k] = im_data[swap_indices[k]]; + } + if (!std::is_same::value) { + // normalize/mirror here to avoid memory copies + // logic from iter_normalize.h, function SetOutImg + for (int k = 0; k < n_channels; ++k) { + if (meanfile_ready_) { + RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k]; + } else { + RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k]; + } } } } @@ -795,5 +813,22 @@ the data type instead of ``float``. .set_body([]() { return new ImageRecordIter2(); }); + +MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter) +.describe(R"code(Iterating on image RecordIO files + +This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as +the data type instead of ``float``. + +)code" ADD_FILELINE) +.add_arguments(ImageRecParserParam::__FIELDS__()) +.add_arguments(ImageRecordParam::__FIELDS__()) +.add_arguments(BatchParam::__FIELDS__()) +.add_arguments(PrefetcherParam::__FIELDS__()) +.add_arguments(ListDefaultAugParams()) +.set_body([]() { + return new ImageRecordIter2(); + }); + } // namespace io } // namespace mxnet diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index af533978a6f5..c68c41f111c1 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -126,6 +126,7 @@ Graph QuantizeGraph(Graph &&src) { const auto offline_params = src.GetAttr>("offline_params"); const auto excluded_nodes = src.GetAttr>("excluded_nodes"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); + const auto use_quantized_data_layer = src.GetAttr("use_quantized_data_layer"); // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key @@ -156,6 +157,10 @@ Graph QuantizeGraph(Graph &&src) { if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i)) { new_node->inputs.emplace_back(mirror_entry); + // If network will directly accept quantized data, simply add data as input. 
+ } else if (use_quantized_data_layer && e.node->is_variable() && + e.node->attrs.name == "data") { + new_node->inputs.emplace_back(mirror_entry); } else if (!NeedQuantize(e.node, excluded_nodes)) { if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back(mirror_entry_map[e]); @@ -199,11 +204,8 @@ Graph QuantizeGraph(Graph &&src) { if (mirror_node->op() == Op::Get("_contrib_dequantize")) { mirror_node = mirror_node->inputs[0].node; } - NodeEntry mirror_entry = NodeEntry{ - mirror_node, e.index, e.version}; + NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; // for quantize node - uint32_t min_index = 1; - uint32_t max_index = 2; if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i)) { // skip non-quantized input @@ -215,19 +217,23 @@ Graph QuantizeGraph(Graph &&src) { // there is only 1min and 1max output from mirror node (which is // currently true) size_t num_outputs = mirror_node->num_outputs() - 2; - min_index = num_outputs + 2 * e.index; - max_index = num_outputs + 2 * e.index + 1; + uint32_t min_index = num_outputs + 2 * e.index; + uint32_t max_index = num_outputs + 2 * e.index + 1; + new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + } else if (use_quantized_data_layer && e.node->is_variable() && + e.node->attrs.name == "data") { + NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + "_min"); + NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + "_max"); + new_node->inputs.emplace_back(NodeEntry{min_var, 0, 0}); + new_node->inputs.emplace_back(NodeEntry{max_var, 0, 0}); } else { - CHECK(mirror_entry_map.count(e)) - << "The input is not quantize or quantized_op"; - } - if (mirror_entry_map.count(e)) { + CHECK(mirror_entry_map.count(e)) << "The input is not quantize or quantized_op"; + uint32_t min_index = 1; + uint32_t max_index = 2; auto quantize_entry = mirror_entry_map[e]; new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, min_index, 0}); new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, max_index, 0}); - } else { - new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); } } diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py index 2e3ff06d2e18..39bfbcdeeafe 100644 --- a/tests/python/train/test_dtype.py +++ b/tests/python/train/test_dtype.py @@ -65,6 +65,30 @@ def get_iterator_uint8(kv): return (train, val) +def get_iterator_int8(kv): + data_shape = (3, 28, 28) + + train = mx.io.ImageRecordInt8Iter( + path_imgrec = "data/cifar/train.rec", + data_shape = data_shape, + batch_size = batch_size, + rand_crop = True, + rand_mirror = True, + num_parts = kv.num_workers, + part_index = kv.rank) + train = mx.io.PrefetchingIter(train) + + val = mx.io.ImageRecordInt8Iter( + path_imgrec = "data/cifar/test.rec", + rand_crop = False, + rand_mirror = False, + data_shape = data_shape, + batch_size = batch_size, + num_parts = kv.num_workers, + part_index = kv.rank) + + return (train, val) + def get_iterator_float32(kv): data_shape = (3, 28, 28) @@ -190,5 +214,10 @@ def test_cifar10(): run_cifar10(train, val, use_module=False) run_cifar10(train, val, use_module=True) + # test int8 input + (train, val) = get_iterator_int8(kv) + run_cifar10(train, val, use_module=False) + run_cifar10(train, val, use_module=True) + if __name__ == "__main__": test_cifar10() From 
3b89042169b92f99cd16a6c8805750ea30999ba3 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 12 Feb 2019 13:53:50 +0800 Subject: [PATCH 2/6] fix lint --- python/mxnet/contrib/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 227dde7e7edb..13189fe0ac42 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -498,7 +498,7 @@ def quantize_model(sym, arg_params, aux_params, qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), quantized_dtype=quantized_dtype, - use_quantized_data_layer = use_quantized_data_layer) + use_quantized_data_layer=use_quantized_data_layer) th_dict = {} if calib_mode is not None and calib_mode != 'none': From ff9f38f2a05bbd9b15d95c375e183f7690dbc28c Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 12 Feb 2019 14:29:07 +0800 Subject: [PATCH 3/6] Add parameter description --- include/mxnet/c_api.h | 6 +++--- python/mxnet/contrib/quantization.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index bca97b7ee3f5..739d05f70bbb 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1565,9 +1565,9 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param excluded_symbols op names to be excluded from being quantized * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline - * \param quantized_dtype the quantized destination type for input data. - * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could. - * \param use_quantized_data_layer if use quantized data layer. + * \param quantized_dtype the quantized destination type for input data + * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could + * \param use_quantized_data_layer if true, use quantized data layer */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 13189fe0ac42..cc7f6fdc0385 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -98,6 +98,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_ avoided. quantized_dtype: str The quantized destination type for input data. + use_quantized_data_layer bool + If true, use quantized data layer. """ num_excluded_symbols = 0 if excluded_symbols is not None: @@ -475,6 +477,8 @@ def quantize_model(sym, arg_params, aux_params, The quantized destination type for input data. Currently support 'int8' , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. Default value is 'int8'. + use_quantized_data_layer bool + If true, use quantized data layer. logger : Object A logging object for printing information during the process of quantization. 
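For reference, a minimal sketch of how the flag documented above is passed through the Python API once patches 1-3 are applied. The checkpoint prefix, record file, shapes and calibration sizes below are placeholders, not part of this series, and patch 5 later drops the flag again in favor of letting quantize_v2 accept int8/uint8 input directly:

    import logging
    import mxnet as mx
    from mxnet.contrib.quantization import quantize_model

    # Placeholder FP32 model and calibration iterator -- substitute your own files.
    sym, arg_params, aux_params = mx.model.load_checkpoint('resnet50_v1', 0)
    calib_data = mx.io.ImageRecordIter(path_imgrec='val_256_q90.rec',
                                       batch_size=32,
                                       data_shape=(3, 224, 224),
                                       label_name='softmax_label')

    qsym, qarg_params, aux_params = quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=mx.cpu(),
        calib_mode='naive', calib_data=calib_data, num_calib_examples=320,
        quantized_dtype='auto',
        use_quantized_data_layer=True,  # flag added by this series (removed again in patch 5)
        logger=logging)
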
From cb91ee4619ef62297143625d9a395e24ce7def88 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 12 Feb 2019 16:18:34 +0800 Subject: [PATCH 4/6] Fix imagenet_inference.py --- example/quantization/imagenet_inference.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index f53afc596798..98e03cb16897 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -226,9 +226,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, label_name=label_name, rand_crop=False, rand_mirror=False, - shuffle=True, - shuffle_chunk_seed=3982304, - seed=48564309, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, **combine_mean_std) elif data_layer_type == 'uint8': data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset, @@ -239,9 +239,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, label_name=label_name, rand_crop=False, rand_mirror=False, - shuffle=True, - shuffle_chunk_seed=3982304, - seed=48564309, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, **combine_mean_std) else: #int8 data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset, @@ -252,9 +252,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, label_name=label_name, rand_crop=False, rand_mirror=False, - shuffle=True, - shuffle_chunk_seed=3982304, - seed=48564309, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, **combine_mean_std) # loading model From 1c066573a31fcec35fcc21e4cce35ee9366416d9 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 20 Feb 2019 15:36:50 +0800 Subject: [PATCH 5/6] Allow quantize_v2 to accept int8 --- .../quantization/imagenet_gen_qsym_mkldnn.py | 8 +- include/mxnet/c_api.h | 4 +- python/mxnet/contrib/quantization.py | 15 +-- src/c_api/c_api_symbolic.cc | 4 +- .../mkldnn/mkldnn_quantize_v2-inl.h | 31 +++++- .../quantization/quantize_graph_pass.cc | 32 +++--- src/operator/quantization/quantize_v2-inl.h | 104 ++++++++++-------- src/operator/quantization/quantize_v2.cc | 6 + 8 files changed, 112 insertions(+), 92 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 72a620dc7f64..d807e7f2d19d 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -147,8 +147,6 @@ def save_params(fname, arg_params, aux_params, logger=None): help='If enabled, the quantize op will ' 'be calibrated offline if calibration mode is ' 'enabled') - parser.add_argument('--use-quantized-data-layer', type=bool, default=True, - help='If enabled, data layer will be already quantized.') args = parser.parse_args() ctx = mx.cpu(0) logging.basicConfig() @@ -275,7 +273,6 @@ def save_params(fname, arg_params, aux_params, logger=None): qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, - use_quantized_data_layer=args.use_quantized_data_layer, logger=logger) sym_name = '%s-symbol.json' % (prefix + '-quantized') else: @@ -298,10 +295,7 @@ def save_params(fname, arg_params, aux_params, logger=None): calib_mode=calib_mode, calib_data=data, 
num_calib_examples=num_calib_batches * batch_size, calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - label_names=(label_name,), - use_quantized_data_layer=args.use_quantized_data_layer, - logger=logger) - + label_names=(label_name,), logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches elif calib_mode == 'naive': diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 8681388d1edf..01d729bedd65 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1573,14 +1573,12 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could - * \param use_quantized_data_layer if true, use quantized data layer */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, const char **excluded_symbols, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype, const bool calib_quantize, - const bool use_quantized_data_layer); + const char *quantized_dtype, const bool calib_quantize); /*! * \brief Set calibration table to node attributes in the sym diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index cc7f6fdc0385..96183bb7a172 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -80,8 +80,7 @@ def _quantize_params(qsym, params, th_dict): quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params -def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8', - use_quantized_data_layer=False): +def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8'): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -98,8 +97,6 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_ avoided. quantized_dtype: str The quantized destination type for input data. - use_quantized_data_layer bool - If true, use quantized data layer. """ num_excluded_symbols = 0 if excluded_symbols is not None: @@ -123,8 +120,7 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_ mx_uint(num_offline), c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), - ctypes.c_bool(True), - ctypes.c_bool(use_quantized_data_layer))) + ctypes.c_bool(True))) return Symbol(out) @@ -423,7 +419,7 @@ def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, - quantized_dtype='int8', use_quantized_data_layer=False, logger=logging): + quantized_dtype='int8', logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -477,8 +473,6 @@ def quantize_model(sym, arg_params, aux_params, The quantized destination type for input data. Currently support 'int8' , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. Default value is 'int8'. 
- use_quantized_data_layer bool - If true, use quantized data layer. logger : Object A logging object for printing information during the process of quantization. @@ -501,8 +495,7 @@ def quantize_model(sym, arg_params, aux_params, ' expected `int8`, `uint8` or `auto`' % quantized_dtype) qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype, - use_quantized_data_layer=use_quantized_data_layer) + quantized_dtype=quantized_dtype) th_dict = {} if calib_mode is not None and calib_mode != 'none': diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 34c3ef740c77..32b63c11dd9a 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -651,8 +651,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, - const bool calib_quantize, - const bool use_quantized_data_layer) { + const bool calib_quantize) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -669,7 +668,6 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); - g.attrs["use_quantized_data_layer"] = std::make_shared(use_quantized_data_layer); g = ApplyPass(std::move(g), "QuantizeGraph"); s->outputs = g.outputs; *ret_sym_handle = s; diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index e201d290e8c6..d6060e54a82c 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -123,13 +123,32 @@ static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContex const std::vector& req, const std::vector& outputs) { const QuantizeV2Param& param = nnvm::get(attrs.parsed); - auto out_type = GetOutputType(param); - if (out_type == mshadow::kUint8) { - MKLDNNQuantizeComputeKer(inputs, outputs, param, req); - } else if (out_type == mshadow::kInt8) { - MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + if (inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8) { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + *outputs[1].data().dptr() = param.min_calib_range.value(); + *outputs[2].data().dptr() = param.max_calib_range.value(); + } else { + if (inputs[0].dtype() == mshadow::kUint8) { + *outputs[1].data().dptr() = 0; + *outputs[2].data().dptr() = 255; + } else { + *outputs[1].data().dptr() = -127; + *outputs[2].data().dptr() = 127; + } + } + if (req[0] != kWriteInplace) { + const_cast(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData()); + MKLDNNStream::Get()->Submit(); + } } else { - LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; + auto out_type = GetOutputType(param); + if (out_type == mshadow::kUint8) { + MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + } else if (out_type == mshadow::kInt8) { + MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + } else { + LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; + } } } diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index c68c41f111c1..af533978a6f5 100644 --- 
a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -126,7 +126,6 @@ Graph QuantizeGraph(Graph &&src) { const auto offline_params = src.GetAttr>("offline_params"); const auto excluded_nodes = src.GetAttr>("excluded_nodes"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); - const auto use_quantized_data_layer = src.GetAttr("use_quantized_data_layer"); // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key @@ -157,10 +156,6 @@ Graph QuantizeGraph(Graph &&src) { if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i)) { new_node->inputs.emplace_back(mirror_entry); - // If network will directly accept quantized data, simply add data as input. - } else if (use_quantized_data_layer && e.node->is_variable() && - e.node->attrs.name == "data") { - new_node->inputs.emplace_back(mirror_entry); } else if (!NeedQuantize(e.node, excluded_nodes)) { if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back(mirror_entry_map[e]); @@ -204,8 +199,11 @@ Graph QuantizeGraph(Graph &&src) { if (mirror_node->op() == Op::Get("_contrib_dequantize")) { mirror_node = mirror_node->inputs[0].node; } - NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; // for quantize node + uint32_t min_index = 1; + uint32_t max_index = 2; if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i)) { // skip non-quantized input @@ -217,23 +215,19 @@ Graph QuantizeGraph(Graph &&src) { // there is only 1min and 1max output from mirror node (which is // currently true) size_t num_outputs = mirror_node->num_outputs() - 2; - uint32_t min_index = num_outputs + 2 * e.index; - uint32_t max_index = num_outputs + 2 * e.index + 1; - new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); - } else if (use_quantized_data_layer && e.node->is_variable() && - e.node->attrs.name == "data") { - NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + "_min"); - NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + "_max"); - new_node->inputs.emplace_back(NodeEntry{min_var, 0, 0}); - new_node->inputs.emplace_back(NodeEntry{max_var, 0, 0}); + min_index = num_outputs + 2 * e.index; + max_index = num_outputs + 2 * e.index + 1; } else { - CHECK(mirror_entry_map.count(e)) << "The input is not quantize or quantized_op"; - uint32_t min_index = 1; - uint32_t max_index = 2; + CHECK(mirror_entry_map.count(e)) + << "The input is not quantize or quantized_op"; + } + if (mirror_entry_map.count(e)) { auto quantize_entry = mirror_entry_map[e]; new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, min_index, 0}); new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, max_index, 0}); + } else { + new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); } } diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 5ae10a7e4fa8..f10458cf76fe 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -137,50 +137,67 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, 
Stream *s = ctx.get_stream(); const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); auto out_type = GetOutputType(param); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), - param.max_calib_range.value(), MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), - param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); + + if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + *outputs[1].dptr() = param.min_calib_range.value(); + *outputs[2].dptr() = param.max_calib_range.value(); } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + if (inputs[0].type_flag_ == mshadow::kUint8) { + *outputs[1].dptr() = 0; + *outputs[2].dptr() = 255; + } else { + *outputs[1].dptr() = -127; + *outputs[2].dptr() = 127; + } } - } else { // model is not calibrated - TShape src_shape, dst_shape; - const size_t actual_float_size = sizeof(float); - const size_t temp_reduce_size = - ConfigReduce(s, inputs[0].shape_, TShape({1}), &src_shape, &dst_shape); - Tensor temp_space = ctx.requested[0].get_space_typed( - Shape1(2 * actual_float_size + temp_reduce_size), s); - const int dev_id = ctx.run_ctx.ctx.dev_id; - TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, - dev_id); - TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, - dev_id); - Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, - Shape1(temp_reduce_size), s); - broadcast::Reduce( - s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - broadcast::Reduce( - s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), - in_max_t.dptr(), MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), - in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + UnaryOp::IdentityCompute(attrs, ctx, {inputs[0]}, req, outputs); + } else { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } + } else { // model is not calibrated + TShape src_shape, 
dst_shape; + const size_t actual_float_size = sizeof(float); + const size_t temp_reduce_size = + ConfigReduce(s, inputs[0].shape_, TShape({1}), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * actual_float_size + temp_reduce_size), s); + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, + dev_id); + TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, + dev_id); + Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, + Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + broadcast::Reduce( + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } } } } @@ -201,7 +218,8 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 3U); const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); - TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32); + CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 || + in_attrs->at(0) == mshadow::kInt8); auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 21410933d35e..eaf97f61745c 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -88,6 +88,12 @@ If min_calib_range isn't presented, the output type will be int8. 
.set_attr("FComputeEx", MKLDNNQuantizeV2Compute) #endif .set_attr("FCompute", QuantizeV2Compute) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; +}) +.set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ + return std::vector{true}; +}) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { From acbff96b8a6547efbe2740b56c96beae01f0dd31 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 1 Mar 2019 12:37:09 +0800 Subject: [PATCH 6/6] make float32 default --- example/quantization/imagenet_inference.py | 60 +++++++++++----------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index 98e03cb16897..47e206303e99 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -107,15 +107,15 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger.info('Loading symbol from file %s' % symbol_file_path) sym = mx.sym.load(symbol_file_path) mod = mx.mod.Module(symbol=sym, context=ctx) - if data_layer_type == "float32": + if data_layer_type == "int8": dshape = mx.io.DataDesc(name='data', shape=( - batch_size,) + data_shape, dtype=np.float32) + batch_size,) + data_shape, dtype=np.int8) elif data_layer_type == 'uint8': dshape = mx.io.DataDesc(name='data', shape=( batch_size,) + data_shape, dtype=np.uint8) - else: # int8 + else: # float32 dshape = mx.io.DataDesc(name='data', shape=( - batch_size,) + data_shape, dtype=np.int8) + batch_size,) + data_shape, dtype=np.float32) mod.bind(for_training=False, inputs_need_grad=False, data_shapes=[dshape]) @@ -217,19 +217,19 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger.info('Dataset for inference: %s' % dataset) # creating data iterator - if data_layer_type == 'float32': - data = mx.io.ImageRecordIter(path_imgrec=dataset, - label_width=1, - preprocess_threads=data_nthreads, - batch_size=batch_size, - data_shape=data_shape, - label_name=label_name, - rand_crop=False, - rand_mirror=False, - shuffle=args.shuffle_dataset, - shuffle_chunk_seed=args.shuffle_chunk_seed, - seed=args.shuffle_seed, - **combine_mean_std) + if data_layer_type == 'int8': + data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **combine_mean_std) elif data_layer_type == 'uint8': data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset, label_width=1, @@ -243,19 +243,19 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, shuffle_chunk_seed=args.shuffle_chunk_seed, seed=args.shuffle_seed, **combine_mean_std) - else: #int8 - data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset, - label_width=1, - preprocess_threads=data_nthreads, - batch_size=batch_size, - data_shape=data_shape, - label_name=label_name, - rand_crop=False, - rand_mirror=False, - shuffle=args.shuffle_dataset, - shuffle_chunk_seed=args.shuffle_chunk_seed, - seed=args.shuffle_seed, - **combine_mean_std) + else: #float32 + data = mx.io.ImageRecordIter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + 
batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **combine_mean_std) # loading model sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
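
With the series applied, imagenet_inference.py selects the iterator and the input dtype from --data-layer-type (float32 remains the default). As a companion to benchmark_score above, here is a minimal sketch of binding a quantized symbol with an int8 data layer and running one synthetic batch; the symbol file name and shapes are placeholders, and the forward pass assumes the graph was quantized so that its data layer accepts int8 input:

    import numpy as np
    import mxnet as mx

    batch_size, data_shape = 64, (3, 224, 224)        # placeholder sizes
    sym = mx.sym.load('model-quantized-symbol.json')  # placeholder symbol file
    mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
    dshape = mx.io.DataDesc(name='data', shape=(batch_size,) + data_shape, dtype=np.int8)
    mod.bind(for_training=False, inputs_need_grad=False, data_shapes=[dshape])
    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

    # Synthetic int8 batch, mirroring benchmark_score when data_layer_type != 'float32'.
    data = [mx.nd.full(shape=shape, val=127, dtype='int8') for _, shape in mod.data_shapes]
    mod.forward(mx.io.DataBatch(data, []), is_train=False)
    mod.get_outputs()[0].wait_to_read()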