From 9ba9d4ce5f895035cc1ddacf90820883447a77ee Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 19 Dec 2018 10:20:17 +0800 Subject: [PATCH 01/38] Enable s8s8 support for MKLDNN convolution. --- .gitignore | 1 + cpp-package/include/mxnet-cpp/monitor.hpp | 6 +- .../quantization/imagenet_gen_qsym_mkldnn.py | 39 +-- example/ssd/quantization.py | 4 +- include/mxnet/c_api.h | 7 +- include/mxnet/executor.h | 2 +- include/mxnet/ndarray.h | 10 +- perl-package/AI-MXNetCAPI/mxnet.i | 5 +- python/mxnet/contrib/quantization.py | 105 +++---- python/mxnet/executor.py | 7 +- python/mxnet/monitor.py | 9 +- .../native/org_apache_mxnet_native_c_api.cc | 3 +- src/c_api/c_api_executor.cc | 5 +- src/c_api/c_api_symbolic.cc | 7 +- src/executor/graph_executor.cc | 40 ++- src/executor/graph_executor.h | 10 +- src/ndarray/ndarray.cc | 39 ++- src/operator/nn/mkldnn/mkldnn_base-inl.h | 24 +- src/operator/nn/mkldnn/mkldnn_base.cc | 61 ++-- .../nn/mkldnn/mkldnn_convolution-inl.h | 21 +- src/operator/nn/mkldnn/mkldnn_convolution.cc | 192 ++++++------ .../mkldnn/mkldnn_dequantize-inl.h | 3 + .../mkldnn/mkldnn_quantize_v2-inl.h | 133 +++++++++ .../quantization/quantization_utils.h | 15 + .../quantization/quantize_graph_pass.cc | 126 ++++---- src/operator/quantization/quantize_v2-inl.h | 226 ++++++++++++++ src/operator/quantization/quantize_v2.cc | 97 ++++++ src/operator/quantization/quantize_v2.cu | 34 +++ src/operator/quantization/requantize-inl.h | 14 - src/operator/subgraph/mkldnn/mkldnn_conv.cc | 282 +++++++++--------- .../subgraph/mkldnn/mkldnn_conv_property.cc | 31 +- tests/python/mkl/test_subgraph.py | 29 +- tests/python/unittest/test_operator.py | 47 ++- 33 files changed, 1121 insertions(+), 513 deletions(-) create mode 100644 src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h create mode 100644 src/operator/quantization/quantize_v2-inl.h create mode 100644 src/operator/quantization/quantize_v2.cc create mode 100644 src/operator/quantization/quantize_v2.cu diff --git a/.gitignore b/.gitignore index 7eb8e7d6e777..9a145f5b8afc 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,7 @@ __pycache__ build cmake-build* data +model recommonmark deps diff --git a/cpp-package/include/mxnet-cpp/monitor.hpp b/cpp-package/include/mxnet-cpp/monitor.hpp index f3584e2e8092..d1e548fbc05a 100644 --- a/cpp-package/include/mxnet-cpp/monitor.hpp +++ b/cpp-package/include/mxnet-cpp/monitor.hpp @@ -43,10 +43,10 @@ inline Monitor::Monitor(int interval, std::regex pattern, StatFunc stat_func) : interval(interval), pattern(pattern), stat_func(stat_func), step(0) { } -inline void Monitor::install(Executor *exe) { +inline void Monitor::install(Executor *exe, bool monitor_all = false) { MXExecutorSetMonitorCallback(exe->handle_, - static_cast(&Monitor::executor_callback), - this); + static_cast(&Monitor::executor_callback), + this, monitor_all); exes.push_back(exe); } diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index c38019fbe7b9..c1715cb31024 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -55,24 +55,24 @@ def convert_from_gluon(model_name, image_shape, classes=1000, logger=None): symnet = mx.symbol.load_json(y.tojson()) params = net.collect_params() args = {} - auxs = {} + auxs = {} for param in params.values(): v = param._reduce() k = param.name if 'running' in k: auxs[k] = v else: - args[k] = v + args[k] = v mod = mx.mod.Module(symbol=symnet, context=mx.cpu(), label_names = 
['softmax_label']) - mod.bind(for_training=False, - data_shapes=[('data', (1,) + + mod.bind(for_training=False, + data_shapes=[('data', (1,) + tuple([int(i) for i in image_shape.split(',')]))]) mod.set_params(arg_params=args, aux_params=auxs) dst_dir = os.path.join(dir_path, 'model') prefix = os.path.join(dir_path, 'model', model_name) if not os.path.isdir(dst_dir): - os.mkdir(dst_dir) + os.mkdir(dst_dir) mod.save_checkpoint(prefix, 0) return prefix @@ -104,7 +104,7 @@ def save_params(fname, arg_params, aux_params, logger=None): 'you can set to custom to load your pre-trained model.') parser.add_argument('--use-gluon-model', type=bool, default=False, help='If enabled, will download pretrained model from Gluon-CV ' - 'and convert to symbolic model ') + 'and convert to symbolic model ') parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--label-name', type=str, default='softmax_label') parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec', @@ -114,7 +114,7 @@ def save_params(fname, arg_params, aux_params, logger=None): help='number of threads for data decoding') parser.add_argument('--num-calib-batches', type=int, default=10, help='number of batches for calibration') - parser.add_argument('--exclude-first-conv', action='store_true', default=True, + parser.add_argument('--exclude-first-conv', action='store_true', default=False, help='excluding quantizing the first conv layer since the' ' input data may have negative value which doesn\'t support at moment' ) parser.add_argument('--shuffle-dataset', action='store_true', default=True, @@ -140,8 +140,8 @@ def save_params(fname, arg_params, aux_params, logger=None): ' thresholds. This mode is expected to produce the best inference accuracy of all three' ' kinds of quantized models if the calibration dataset is representative enough of the' ' inference dataset.') - parser.add_argument('--quantized-dtype', type=str, default='uint8', - choices=['int8', 'uint8'], + parser.add_argument('--quantized-dtype', type=str, default='auto', + choices=['auto', 'int8', 'uint8'], help='quantization destination data type for input data') parser.add_argument('--enable-calib-quantize', type=bool, default=True, help='If enabled, the quantize op will ' @@ -203,35 +203,30 @@ def save_params(fname, arg_params, aux_params, logger=None): if args.model == 'imagenet1k-resnet-152': rgb_mean = '0,0,0' rgb_std = '1,1,1' - calib_layer = lambda name: name.endswith('_output') - excluded_sym_names += ['flatten0', 'fc1', 'pooling0'] + excluded_sym_names += ['flatten0', 'fc1'] if exclude_first_conv: excluded_sym_names += ['conv0'] elif args.model == 'imagenet1k-inception-bn': rgb_mean = '123.68,116.779,103.939' rgb_std = '1,1,1' - calib_layer = lambda name: name.endswith('_output') excluded_sym_names += ['flatten', 'fc1'] if exclude_first_conv: excluded_sym_names += ['conv_1'] elif args.model in ['resnet50_v1', 'resnet101_v1']: rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' - calib_layer = lambda name: name.endswith('_output') - excluded_sym_names += ['resnetv10_dense0_fwd', 'resnetv10_pool0_fwd'] + excluded_sym_names += ['resnetv10_dense0_fwd'] if exclude_first_conv: excluded_sym_names += ['resnetv10_conv0_fwd'] elif args.model == 'squeezenet1.0': rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' - calib_layer = lambda name: name.endswith('_output') excluded_sym_names += ['squeezenet0_flatten0_flatten0'] if exclude_first_conv: excluded_sym_names += ['squeezenet0_conv0_fwd'] elif 
args.model == 'mobilenet1.0': rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' - calib_layer = lambda name: name.endswith('_output') excluded_sym_names += ['mobilenet0_flatten0_flatten0', 'mobilenet0_dense0_fwd', 'mobilenet0_pool0_fwd'] @@ -240,16 +235,13 @@ def save_params(fname, arg_params, aux_params, logger=None): elif args.model == 'inceptionv3': rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' - calib_layer = lambda name: name.endswith('_output') - excluded_sym_names += ['inception30_dense0_fwd', - 'inception30_pool0_fwd'] + excluded_sym_names += ['inception30_dense0_fwd'] if exclude_first_conv: excluded_sym_names += ['inception30_conv0_fwd'] elif args.model == 'custom': # add rgb mean/std of your model. rgb_mean = '0,0,0' rgb_std = '0,0,0' - calib_layer = lambda name: name.endswith('_output') # add layer names you donnot want to quantize. # add conv/pool layer names that has negative inputs # since Intel MKL-DNN only support uint8 quantization temporary. @@ -272,7 +264,7 @@ def save_params(fname, arg_params, aux_params, logger=None): mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} logger.info('rgb_std = %s' % rgb_std) rgb_std = [float(i) for i in rgb_std.split(',')] - std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} + std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} if calib_mode == 'none': logger.info('Quantizing FP32 model %s' % args.model) @@ -301,9 +293,8 @@ def save_params(fname, arg_params, aux_params, logger=None): ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, calib_data=data, num_calib_examples=num_calib_batches * batch_size, - calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - label_names=(label_name,), calib_quantize_op = True, - logger=logger) + calib_layer=None, quantized_dtype=args.quantized_dtype, + label_names=(label_name,), logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches elif calib_mode == 'naive': diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py index 231cc99f93bc..4ed28dd03c2f 100644 --- a/example/ssd/quantization.py +++ b/example/ssd/quantization.py @@ -157,9 +157,7 @@ def save_params(fname, arg_params, aux_params, logger=None): calib_mode=calib_mode, calib_data=eval_iter, num_calib_examples=num_calib_batches * batch_size, calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - label_names=(label_name,), - calib_quantize_op = True, - logger=logger) + label_names=(label_name,), logger=logger) sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300') param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch) qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index e9f1e2d6cccc..1c7575ccd688 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1556,13 +1556,12 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. - * \param calib_quantize whether calibrate quantize op with offline calibration data. 
*/ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, const char **excluded_symbols, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype, const bool calib_quantize); + const char *quantized_dtype); /*! * \brief Set calibration table to node attributes in the sym @@ -1833,10 +1832,12 @@ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, /*! * \brief set a call back to notify the completion of operation + * \param monitor_all If true, monitor both input and output, otherwise monitor output only. */ MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, - void* callback_handle); + void* callback_handle, + bool monitor_all); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 0ab04b86a0a1..c3f2459c0f59 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -174,7 +174,7 @@ class Executor { /*! * \brief Install a callback to notify the completion of operation. */ - virtual void SetMonitorCallback(const MonitorCallback& callback) {} + virtual void SetMonitorCallback(const MonitorCallback& callback, bool monitor_input) {} }; // class executor } // namespace mxnet #endif // MXNET_EXECUTOR_H_ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 4ba13ca6498a..5de42e19a657 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -694,9 +694,13 @@ class NDArray { /* * Create NDArray from mkldnn memory. * mkldnn_mem The mkldnn memory to be managed. - * static_data If true, mkldnn memory won't be freed on destruction. */ - explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true); + explicit NDArray(const std::shared_ptr &mkldnn_mem); + /* + * Create NDArray from mkldnn memory descriptor. + * mem_pd The mkldnn memory descriptor to be created. + */ + explicit NDArray(mkldnn::memory::primitive_desc mem_pd); /* * Test if the data is stored in one of special MKLDNN format. */ @@ -776,7 +780,7 @@ class NDArray { /*! * \ Fix mkldnn memory descriptor mismatch from NDArray. */ - void UpdateMKLDNNMemDesc(); + void UpdateMKLDNNMemDesc(mkldnn::memory::format format); #endif /*! diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index b1907f5cd7ec..ca6623572dfb 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -1614,10 +1614,12 @@ int MXExecutorReshape(int partial_shaping, /*! * \brief set a call back to notify the completion of operation + * \param monitor_all If true, monitor both input and output, otherwise monitor output only. 
*/ int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, - void* callback_handle); + void* callback_handle, + bool monitor_all); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- @@ -2167,4 +2169,3 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** cuda_kernel_ mx_uint grid_dim_z, mx_uint block_dim_x, mx_uint block_dim_y, mx_uint block_dim_z, mx_uint shared_mem); - diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 61ad8a3ec704..0f32cbc82f37 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -26,6 +26,7 @@ import ctypes import logging import os +import sys import numpy as np from ..base import _LIB, check_call, py_str from ..base import c_array, c_str, mx_uint, c_str_array @@ -80,8 +81,7 @@ def _quantize_params(qsym, params, th_dict): quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params -def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, - quantized_dtype='int8', calib_quantize_op=False): +def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8'): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -98,8 +98,6 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, avoided. quantized_dtype: str The quantized destination type for input data. - calib_quantize_op : bool - Whether perform offline calibration for quantize op. """ num_excluded_symbols = 0 if excluded_symbols is not None: @@ -122,8 +120,7 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, c_str_array(excluded_symbols), mx_uint(num_offline), c_array(ctypes.c_char_p, offline), - c_str(quantized_dtype), - ctypes.c_bool(calib_quantize_op))) + c_str(quantized_dtype))) return Symbol(out) @@ -139,18 +136,20 @@ def __init__(self, include_layer=None, logger=None): def collect(self, name, arr): """Callback function for collecting layer output NDArrays.""" - name = py_str(name) - if self.include_layer is not None and not self.include_layer(name): - return - handle = ctypes.cast(arr, NDArrayHandle) - arr = NDArray(handle, writable=False).copyto(cpu()) - if self.logger is not None: - self.logger.info("Collecting layer %s output of shape %s" % (name, arr.shape)) - if name in self.nd_dict: - self.nd_dict[name].append(arr) - else: - self.nd_dict[name] = [arr] - + try: + name = py_str(name) + if self.include_layer is not None and not self.include_layer(name): + return + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False).copyto(cpu()) + if self.logger is not None: + self.logger.info("Collecting layer %s output of shape %s" % (name, arr.shape)) + if name in self.nd_dict: + self.nd_dict[name].append(arr) + else: + self.nd_dict[name] = [arr] + except KeyboardInterrupt: + sys.exit(1) class _LayerOutputMinMaxCollector(object): """Saves layer output min and max values in a dict with layer names as keys. 
@@ -163,23 +162,25 @@ def __init__(self, include_layer=None, logger=None): def collect(self, name, arr): """Callback function for collecting min and max values from an NDArray.""" - name = py_str(name) - if self.include_layer is not None and not self.include_layer(name): - return - handle = ctypes.cast(arr, NDArrayHandle) - arr = NDArray(handle, writable=False) - min_range = ndarray.min(arr).asscalar() - max_range = ndarray.max(arr).asscalar() - if name in self.min_max_dict: - cur_min_max = self.min_max_dict[name] - self.min_max_dict[name] = (min(cur_min_max[0], min_range), - max(cur_min_max[1], max_range)) - else: - self.min_max_dict[name] = (min_range, max_range) - if self.logger is not None: - self.logger.info("Collecting layer %s output min_range=%f, max_range=%f" - % (name, min_range, max_range)) - + try: + name = py_str(name) + if self.include_layer is not None and not self.include_layer(name): + return + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False) + min_range = ndarray.min(arr).asscalar() + max_range = ndarray.max(arr).asscalar() + if name in self.min_max_dict: + cur_min_max = self.min_max_dict[name] + self.min_max_dict[name] = (min(cur_min_max[0], min_range), + max(cur_min_max[1], max_range)) + else: + self.min_max_dict[name] = (min_range, max_range) + if self.logger is not None: + self.logger.info("Collecting layer %s min_range=%f, max_range=%f" + % (name, min_range, max_range)) + except KeyboardInterrupt: + sys.exit(1) def _calibrate_quantized_sym(qsym, th_dict): """Given a dictionary containing the thresholds for quantizing the layers, @@ -210,7 +211,7 @@ def _collect_layer_statistics(mod, data, collector, max_num_examples=None, logge if not isinstance(data, DataIter): raise ValueError('Only supports data as a type of DataIter, while received type %s' % str(type(data))) - mod._exec_group.execs[0].set_monitor_callback(collector.collect) + mod._exec_group.execs[0].set_monitor_callback(collector.collect, monitor_all=True) num_batches = 0 num_examples = 0 for batch in data: @@ -265,6 +266,9 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=line-too-long def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. + The reference distribution is `q`, and the candidate distribution is `p`. + `q` is a truncated version of the original distribution. 
+ Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ if isinstance(arr, NDArray): @@ -286,12 +290,12 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): min_val = np.min(arr) max_val = np.max(arr) th = max(abs(min_val), abs(max_val)) + if min_val >= 0: + num_quantized_bins = (num_quantized_bins // 2) * 4 + 1 hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) zero_bin_idx = num_bins // 2 num_half_quantized_bins = num_quantized_bins // 2 - assert np.allclose(hist_edges[zero_bin_idx] + hist_edges[zero_bin_idx + 1], - 0, rtol=1e-5, atol=1e-7) thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) divergence = np.zeros_like(thresholds) @@ -315,10 +319,10 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): right_outlier_count = np.sum(hist[p_bin_idx_stop:]) p[-1] += right_outlier_count # is_nonzeros[k] indicates whether hist[k] is nonzero - is_nonzeros = (sliced_nd_hist != 0).astype(np.int32) + is_nonzeros = (p != 0).astype(np.int32) # calculate how many bins should be merged to generate quantized distribution q - num_merged_bins = p.size // num_quantized_bins + num_merged_bins = sliced_nd_hist.size // num_quantized_bins # merge hist into num_quantized_bins bins for j in range(num_quantized_bins): start = j * num_merged_bins @@ -326,17 +330,17 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): quantized_bins[j] = sliced_nd_hist[start:stop].sum() quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() # expand quantized_bins into p.size bins - q = np.zeros(p.size, dtype=np.float32) + q = np.zeros(sliced_nd_hist.size, dtype=np.float32) for j in range(num_quantized_bins): start = j * num_merged_bins if j == num_quantized_bins - 1: - stop = -1 + stop = len(is_nonzeros) else: stop = start + num_merged_bins norm = is_nonzeros[start:stop].sum() if norm != 0: q[start:stop] = float(quantized_bins[j]) / float(norm) - q[sliced_nd_hist == 0] = 0 + q[p == 0] = 0 p = _smooth_distribution(p) # There is a chance that q is an invalid probability distribution. try: @@ -344,7 +348,6 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): except ValueError: divergence[i - num_half_quantized_bins] = float("inf") divergence[i - num_half_quantized_bins] = stats.entropy(p, q) - quantized_bins[:] = 0 min_divergence_idx = np.argmin(divergence) min_divergence = divergence[min_divergence_idx] @@ -424,7 +427,7 @@ def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, - quantized_dtype='int8', calib_quantize_op=False, logger=logging): + quantized_dtype='int8', logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -476,9 +479,8 @@ def quantize_model(sym, arg_params, aux_params, all the layers' outputs that need requantization will be collected. quantized_dtype : str The quantized destination type for input data. Currently support 'int8' - and 'uint8', default value is 'int8'. - calib_quantize_op: bool - Whether calibrate quantize op with its input calibration data. The quantize op's input should be in calib_layer + , 'uint8' and 'auto'. 
'auto' means automatically select output type according to calibration result. + Default value is 'int8'. logger : Object A logging object for printing information during the process of quantization. @@ -496,13 +498,12 @@ def quantize_model(sym, arg_params, aux_params, ' while received type %s' % str(type(excluded_sym_names))) logger.info('Quantizing symbol') - if quantized_dtype not in ('int8', 'uint8'): + if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' - ' expected `int8` or `uint8`' % quantized_dtype) + ' expected `int8`, `uint8` or `auto`' % quantized_dtype) qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype, - calib_quantize_op=calib_quantize_op) + quantized_dtype=quantized_dtype) th_dict = {} if calib_mode is not None and calib_mode != 'none': diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index fcd5406236e9..ddb2dab1098e 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -234,13 +234,15 @@ def backward(self, out_grads=None, is_train=True): ndarray, ctypes.c_int(is_train))) - def set_monitor_callback(self, callback): + def set_monitor_callback(self, callback, monitor_all=False): """Install callback for monitor. Parameters ---------- callback : function Takes a string and an NDArrayHandle. + monitor_all : bool, default False + If true, monitor both input and output, otherwise monitor output only. Examples -------- @@ -254,7 +256,8 @@ def set_monitor_callback(self, callback): check_call(_LIB.MXExecutorSetMonitorCallback( self.handle, self._monitor_callback, - None)) + None, + ctypes.c_int(monitor_all))) @property def arg_dict(self): diff --git a/python/mxnet/monitor.py b/python/mxnet/monitor.py index e3185a1281af..2e10708e72f4 100644 --- a/python/mxnet/monitor.py +++ b/python/mxnet/monitor.py @@ -31,7 +31,7 @@ class Monitor(object): - """Monitor outputs, weights, and gradients for debugging. + """Monitor inputs, outputs, weights, and gradients for debugging. Parameters ---------- @@ -46,8 +46,10 @@ class Monitor(object): Only tensors with names that match `name_pattern` will be included. For example, '.*weight|.*output' will print all weights and outputs and '.*backward.*' will print all gradients. + monitor_all : bool, default False + If true, monitor both input and output, otherwise monitor output only. """ - def __init__(self, interval, stat_func=None, pattern='.*', sort=False): + def __init__(self, interval, stat_func=None, pattern='.*', sort=False, monitor_all=False): if stat_func is None: def asum_stat(x): """returns |x|/size(x), async execution.""" @@ -61,6 +63,7 @@ def asum_stat(x): self.exes = [] self.re_prog = re.compile(pattern) self.sort = sort + self.monitor_all = monitor_all def stat_helper(name, array): """wrapper for executor callback""" array = ctypes.cast(array, NDArrayHandle) @@ -79,7 +82,7 @@ def install(self, exe): exe : mx.executor.Executor The Executor (returned by symbol.bind) to install to. 
""" - exe.set_monitor_callback(self.stat_helper) + exe.set_monitor_callback(self.stat_helper, self.monitor_all) self.exes.append(exe) def tic(self): diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 17d166eac345..663d3a4142fa 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -915,7 +915,8 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallbac jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj); return MXExecutorSetMonitorCallback(reinterpret_cast(executorPtr), ExecutorMonitorCallbackFunc, - reinterpret_cast(callbackFuncObjGlb)); + reinterpret_cast(callbackFuncObjGlb), + false); } JNIEXPORT jstring JNICALL Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index e2e53c7261fa..b15f2d508644 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -649,7 +649,8 @@ int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, - void* callback_handle) { + void* callback_handle, + bool monitor_all) { API_BEGIN(); ExecutorMonitorCallback callback_temp = callback; void* callback_handle_temp = callback_handle; @@ -658,6 +659,6 @@ int MXExecutorSetMonitorCallback(ExecutorHandle handle, callback_temp(name, handle, callback_handle_temp); }; Executor *exec = static_cast(handle); - exec->SetMonitorCallback(clbk); + exec->SetMonitorCallback(clbk, monitor_all); API_END(); } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 73a8a7ca6f86..0a49b88e5429 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -650,8 +650,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const char **excluded_op_names, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype, - const bool calib_quantize) { + const char *quantized_dtype) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -668,7 +667,6 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); - g.attrs["calib_quantize"] = std::make_shared(calib_quantize); g = ApplyPass(std::move(g), "QuantizeGraph"); s->outputs = g.outputs; *ret_sym_handle = s; @@ -685,10 +683,9 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, API_BEGIN(); nnvm::Symbol* sym = static_cast(qsym_handle); nnvm::Graph g = Symbol2Graph(*sym); - const std::string prefix = "quantized_"; std::unordered_map> calib_table; for (size_t i = 0; i < num_layers; ++i) { - calib_table.emplace(prefix+layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); + calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); } g.attrs["calib_table"] = std::make_shared(std::move(calib_table)); g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index d866ad135573..8302dc133c64 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -101,9 +101,10 @@ void 
GraphExecutor::Print(std::ostream &os) const { // NOLINT(*) os << "Total " << 11 << " TempSpace resource requested\n"; } -void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback) { +void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) { CHECK(callback) << "invalid callback"; monitor_callback_ = callback; + monitor_all_ = monitor_all; } const std::vector& GraphExecutor::outputs() const { @@ -1291,7 +1292,36 @@ void GraphExecutor::BulkInferenceOpSegs() { } } -void GraphExecutor::ExecuteMonCallback(size_t nid) { +void GraphExecutor::ExecuteMonInputCallback(size_t nid) { + static const auto& flist_inputs = + nnvm::Op::GetAttr("FListInputNames"); + const auto& idx = graph_.indexed_graph(); + std::vector input_names; + OpNode& opnode = op_nodes_[nid]; + const auto& inode = idx[nid]; + const auto& node = idx[nid].source; + if (flist_inputs.count(node->op())) { + input_names = flist_inputs[node->op()](node->attrs); + } else { + for (size_t i = 0; i < node->num_inputs(); ++i) { + input_names.emplace_back("input" + std::to_string(i)); + } + } + CHECK_EQ(opnode.exec->in_array.size(), input_names.size()); + for (size_t i = 0; i < opnode.exec->in_array.size(); ++i) { + if (node->inputs[i].node->is_variable()) { + // Monitor variable + NDArray *cpy = new NDArray(opnode.exec->in_array[i]); + std::string name = node->inputs[i].node->attrs.name; + this->monitor_callback_(name.c_str(), reinterpret_cast(cpy)); + } + NDArray *cpy = new NDArray(opnode.exec->in_array[i]); + std::string name = inode.source->attrs.name + "_" + input_names[i]; + this->monitor_callback_(name.c_str(), reinterpret_cast(cpy)); + } +} + +void GraphExecutor::ExecuteMonOutputCallback(size_t nid) { static const auto& flist_outputs = nnvm::Op::GetAttr("FListOutputNames"); const auto& idx = graph_.indexed_graph(); @@ -1341,6 +1371,10 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { if (inode.source->is_variable()) continue; OpNode& opnode = op_nodes_[nid]; if (op_nodes_[nid].skip_exec_node) continue; + // Monitor callbacks + if (monitor_callback_ && monitor_all_) { + ExecuteMonInputCallback(nid); + } opnode.exec->op_ctx.is_train = is_train; opnode.exec->op_ctx.need_grad = need_grad_; if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) { @@ -1359,7 +1393,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { } // Monitor callbacks if (monitor_callback_) { - ExecuteMonCallback(nid); + ExecuteMonOutputCallback(nid); } } } diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index f5f032e3f2e6..722714716aa4 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -68,7 +68,7 @@ class GraphExecutor : public Executor { const std::unordered_map& arg_grad_map() const override; const std::unordered_map& aux_state_map() const override; void Print(std::ostream &os) const override; // NOLINT(*) - void SetMonitorCallback(const MonitorCallback& callback) override; + void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) override; // Initialize the rest of attributes // after setting up arguments. void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, @@ -209,8 +209,10 @@ class GraphExecutor : public Executor { * ret.opr Can be nullptr if creation failed. 
*/ CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end); - // run the monitor callback for node `nid` - void ExecuteMonCallback(size_t nid); + // run the monitor callback for input of node `nid` + void ExecuteMonInputCallback(size_t nid); + // run the monitor callback for output of node `nid` + void ExecuteMonOutputCallback(size_t nid); // peform bulking and segmentation on an inference graph void BulkInferenceOpSegs(); // perform bulking and segmentation on a training graph @@ -250,6 +252,8 @@ class GraphExecutor : public Executor { size_t num_forward_nodes_{0}; // monitor call back std::function monitor_callback_{nullptr}; + // monitor both input and output from monitor call back + bool monitor_all_{false}; // whether to enable bulk execution bool prefer_bulk_execution_; // cached segment operator diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 081d4e759323..a1c3497820e3 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 -NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data) +NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { - auto mem_pd = mkldnn_mem->get_primitive_desc(); auto mem_desc = mem_pd.desc(); shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); dtype_ = get_mxnet_type(mem_desc.data.data_type); - auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_); - ptr_ = std::make_shared(data, 0); + ptr_ = std::make_shared(shape_, Context::CPU(), true, dtype_); + ptr_->CheckAndAlloc(mem_pd.get_size()); ptr_->mkl_mem_ = std::make_shared(mem_pd, ptr_->shandle.dptr); - ptr_->static_data = static_data; +} + +NDArray::NDArray(const std::shared_ptr &mkldnn_mem) + : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + auto mem_pd = mkldnn_mem->get_primitive_desc(); + auto mem_desc = mem_pd.desc(); + shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); + dtype_ = get_mxnet_type(mem_desc.data.data_type); + ptr_ = std::make_shared(shape_, Context::CPU(), true, dtype_); + ptr_->shandle.dptr = mkldnn_mem->get_data_handle(); + ptr_->shandle.size = mem_pd.get_size(); + ptr_->delay_alloc = false; + ptr_->mkl_mem_ = std::make_shared(mkldnn_mem); + ptr_->static_data = true; } NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { @@ -717,19 +729,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & return ptr_->mkl_mem_->GetRaw(); } -void NDArray::UpdateMKLDNNMemDesc() { +void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) { const mkldnn::memory *mem = GetMKLDNNData(); auto mem_desc = mem->get_primitive_desc().desc(); auto this_dtype = get_mkldnn_type(dtype()); - if (this_dtype != mem_desc.data.data_type) { - mkldnn::memory::desc data_md( - mkldnn::memory::dims(mem_desc.data.dims, - mem_desc.data.dims + mem_desc.data.ndims), - this_dtype, static_cast(mem_desc.data.format)); - mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); - ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr)); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); - } + mkldnn::memory::desc data_md( + mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims), + this_dtype, format); + mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); + ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, 
ptr_->shandle.dptr)); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); } #endif diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 17e74094c2bb..660a27d8be61 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -189,6 +189,9 @@ static int GetTypeSize(int dtype) { } static inline size_t GetArraySize(const NDArray &arr) { + if (arr.IsMKLDNNData()) { + return arr.GetMKLDNNData()->get_primitive_desc().get_size(); + } return arr.shape().Size() * GetTypeSize(arr.dtype()); } @@ -237,21 +240,20 @@ static inline size_t GetMemDescSize(const mkldnn::memory::desc &md) { return ret; } -inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int ndim) { +inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1) { + int ndim = arr.shape().ndim(); mkldnn::memory::dims dims(ndim); + dtype = (dtype == -1) ? arr.dtype() : dtype; for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; - return mkldnn::memory::desc{dims, get_mkldnn_type(arr.dtype()), - mkldnn::memory::format::any}; -} - -inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr) { - return GetMemDesc(arr, arr.shape().ndim()); + return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format::any}; } inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, - int num_groups) { + int num_groups, + bool quantized = false) { + int dtype = quantized ? mshadow::kInt8 : arr.dtype(); if (num_groups == 1) { - return GetMemDesc(arr); + return GetMemDesc(arr, dtype); } else { CHECK_EQ(arr.shape().ndim(), 4U); mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups, @@ -259,7 +261,7 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, static_cast(arr.shape()[1]), static_cast(arr.shape()[2]), static_cast(arr.shape()[3])}; - return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()), + return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format::any}; } } @@ -437,6 +439,8 @@ static inline void CreateDefaultInputs(const std::vector &arrs, } } +const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups); + const mkldnn::memory *GetWeights(const NDArray &arr, const mkldnn::memory::primitive_desc &target_pd, int num_groups); diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 5db51817db9d..92e697d35ea5 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -228,51 +228,44 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) { } } -const mkldnn::memory *GetWeights(const NDArray &arr, - const mkldnn::memory::primitive_desc &target_pd, - int num_groups) { - const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd); - // If the weight array already uses the target layout, simply return it - // directly. 
- if (mem) - return mem; - +const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) { mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype()); + const mkldnn::memory *mem = nullptr; auto engine = CpuEngine::Get()->get_engine(); if (arr.shape().ndim() == 2) { - mkldnn::memory::dims tz = mkldnn::memory::dims{ - static_cast(arr.shape()[0]), static_cast(arr.shape()[1])}; - mkldnn::memory::desc md = - mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi}; - mkldnn::memory::primitive_desc pd = - mkldnn::memory::primitive_desc{md, engine}; + mkldnn::memory::dims tz = + mkldnn::memory::dims{static_cast(arr.shape()[0]), static_cast(arr.shape()[1])}; + mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi}; + mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, engine}; mem = arr.GetMKLDNNData(pd); } else if (arr.shape().ndim() == 4 && num_groups == 1) { - mkldnn::memory::dims tz = mkldnn::memory::dims{ - static_cast(arr.shape()[0]), static_cast(arr.shape()[1]), - static_cast(arr.shape()[2]), static_cast(arr.shape()[3])}; - mkldnn::memory::desc md = - mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw}; - mkldnn::memory::primitive_desc pd = - mkldnn::memory::primitive_desc{md, engine}; + mkldnn::memory::dims tz = + mkldnn::memory::dims{static_cast(arr.shape()[0]), static_cast(arr.shape()[1]), + static_cast(arr.shape()[2]), static_cast(arr.shape()[3])}; + mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw}; + mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, engine}; mem = arr.GetMKLDNNData(pd); } else if (arr.shape().ndim() == 4) { - mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups, - static_cast(arr.shape()[0] / num_groups), - static_cast(arr.shape()[1]), - static_cast(arr.shape()[2]), - static_cast(arr.shape()[3])}; - mkldnn::memory::desc md = - mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw}; - mkldnn::memory::primitive_desc pd = - mkldnn::memory::primitive_desc{md, engine}; + mkldnn::memory::dims tz = mkldnn::memory::dims{ + num_groups, static_cast(arr.shape()[0] / num_groups), static_cast(arr.shape()[1]), + static_cast(arr.shape()[2]), static_cast(arr.shape()[3])}; + mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw}; + mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, engine}; mem = arr.GetMKLDNNData(pd); } else { LOG(FATAL) << "The weight array has an unsupported number of dimensions"; - return nullptr; } - if (mem == nullptr) - mem = arr.GetMKLDNNDataReorder(target_pd); + return mem; +} + +const mkldnn::memory *GetWeights(const NDArray &arr, + const mkldnn::memory::primitive_desc &target_pd, int num_groups) { + const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd); + // If the weight array already uses the target layout, simply return it + // directly. 
+ if (mem) return mem; + mem = GetWeights(arr, num_groups); + if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_pd); if (mem->get_primitive_desc() == target_pd) return mem; auto ret = TmpMemMgr::Get()->Alloc(target_pd); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 971c66ad9dd2..a27dced910cb 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -85,23 +85,28 @@ static inline bool IsOutputUInt8(const MKLDNNConvParam &mkldnn_param) { mkldnn_param.with_postsum_relu; } -mkldnn::convolution_forward::primitive_desc -GetConvFwdImpl(const MKLDNNConvFullParam ¶m, const bool is_train, - const NDArray &data, const NDArray &weights, const NDArray *bias, - const NDArray &output); +mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, + const bool is_train, + const NDArray &data, + const NDArray &weights, + const NDArray *bias, + const NDArray &output); class MKLDNNConvForward { public: mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, - const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output) - : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) {} + MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, const NDArray &output); void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, const mkldnn::memory *bias, const mkldnn::memory &output); + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { + this->data_->set_data_handle(data.get_data_handle()); + this->out_->set_data_handle(output.get_data_handle()); + } + const mkldnn::convolution_forward &GetFwd() const { return *fwd_; } diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index dd1f3ec07d70..955dfcf5d71e 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -42,18 +42,12 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4; } -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const MKLDNNConvFullParam ¶m, const bool is_train, - const NDArray &data, const NDArray &weights, const NDArray *bias, - const NDArray &output) { - auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.conv_param.num_group); - auto out_md = GetMemDesc(output); +static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( + const MKLDNNConvFullParam ¶m, const bool is_train, const mkldnn::memory::desc &data_md, + const mkldnn::memory::desc &weight_md, const mkldnn::memory::desc *bias_md, + const mkldnn::memory::desc &out_md) { auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.conv_param.stride.ndim(), 2U); - CHECK_GE(param.conv_param.pad.ndim(), 2U); - CHECK_GE(param.conv_param.dilate.ndim(), 2U); + auto prop = is_train ? 
mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; mkldnn::memory::dims strides{0, 0}; strides[0] = param.conv_param.stride[0]; strides[1] = param.conv_param.stride[1]; @@ -63,18 +57,18 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( mkldnn::primitive_attr attr; mkldnn::post_ops ops; if (param.mkldnn_param.with_relu) { - float scale = 1.0f; // for fp32, scale is 1. - float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. - float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. ops.append_eltwise(scale, eltwise_relu, alpha, beta); } if (param.mkldnn_param.with_sum) { ops.append_sum(param.sum_scale); } if (param.mkldnn_param.with_postsum_relu) { - float scale = 1.0f; // for fp32, scale is 1. - float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. - float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. ops.append_eltwise(scale, eltwise_relu, alpha, beta); } attr.set_post_ops(ops); @@ -85,62 +79,67 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( attr.set_int_output_round_mode(round_nearest); } - // MKL-DNN introduced padded formats since 0.15 which require more memory - // for computation compared with the actual tensor size. Currently, MKL-DNN - // operators are still reusing those memory from memory planning and the - // memory size may smaller than what MKL-DNN kernels require. So here we need - // select suboptimal kernel for computation according to tensor sizes. 
- if (param.conv_param.dilate.ndim() == 0 && bias == nullptr) { - mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine); - while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + if (param.conv_param.dilate.ndim() == 0 && bias_md == nullptr) { + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else if (param.conv_param.dilate.ndim() == 0) { - auto bias_md = GetMemDesc(*bias); - mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, - data_md, weight_md, bias_md, out_md, strides, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine); - while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, + weight_md, *bias_md, out_md, strides, padding, padding, + mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { mkldnn::memory::dims dilates{0, 0}; dilates[0] = param.conv_param.dilate[0] - 1; dilates[1] = param.conv_param.dilate[1] - 1; - if (bias == nullptr) { - mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine); - while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; - } else { - auto bias_md = GetMemDesc(*bias); - mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, - data_md, weight_md, bias_md, out_md, strides, - dilates, padding, padding, + if (bias_md == nullptr) { + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine); - while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); + } else { + mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, + weight_md, *bias_md, out_md, 
strides, dilates, padding, + padding, mkldnn::padding_kind::zero); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); + } + } +} + +mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, + const bool is_train, const NDArray &data, + const NDArray &weights, + const NDArray *bias, + const NDArray &output) { + CHECK_GE(param.conv_param.stride.ndim(), 2U); + CHECK_GE(param.conv_param.pad.ndim(), 2U); + CHECK_GE(param.conv_param.dilate.ndim(), 2U); + auto data_md = GetMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.conv_param.num_group, param.mkldnn_param.quantized); + auto out_md = GetMemDesc(output); + auto bias_md = + bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias)) + : mkldnn::memory::desc{ + {}, mkldnn::memory::data_type::data_undef, mkldnn::memory::format::any}; + auto bias_md_ptr = bias ? &bias_md : nullptr; + try { + auto conv_pd = GetConvFwdImpl(param, is_train, data_md, weight_md, bias_md_ptr, out_md); + while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || + conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || + (!param.mkldnn_param.quantized && + conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) { + CHECK(conv_pd.next_impl()) << "No convolution implementation for this request."; } + return conv_pd; + } catch (mkldnn::error &e) { + if (e.status == mkldnn_unimplemented && param.mkldnn_param.quantized) { + LOG(ERROR) << "AVX512-BW support or Intel(R) MKL dependency is " + "required for int8 convolution"; + } else { + LOG(ERROR) << e.message; + } + throw; } } @@ -270,48 +269,31 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( } } -void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, - const mkldnn::memory &weight, - const mkldnn::memory *bias, - const mkldnn::memory &output) { - if (this->data_ == nullptr) - this->data_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data_->set_data_handle(data.get_data_handle()); - - if (this->weight_ == nullptr) - this->weight_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight_->set_data_handle(weight.get_data_handle()); - - if (this->out_ == nullptr) - this->out_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (bias != nullptr) { - if (this->bias_ == nullptr) - this->bias_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.bias_primitive_desc(), bias->get_data_handle())); - else - this->bias_->set_data_handle(bias->get_data_handle()); - if (this->fwd_ == nullptr) - this->fwd_ = std::shared_ptr( - new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data_), - mkldnn::primitive::at(*this->weight_), - mkldnn::primitive::at(*this->bias_), - *this->out_)); - } else if (this->fwd_ == nullptr) { - this->fwd_ = std::shared_ptr( - new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data_), - mkldnn::primitive::at(*this->weight_), - *this->out_)); +MKLDNNConvForward::MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output) + : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) { + data_ = std::make_shared(fwd_pd.src_primitive_desc(), nullptr); + 
weight_ = std::make_shared(fwd_pd.weights_primitive_desc(), nullptr); + out_ = std::make_shared(fwd_pd.dst_primitive_desc(), nullptr); + if (bias) { + bias_ = std::make_shared(fwd_pd.bias_primitive_desc(), nullptr); + fwd_ = std::make_shared(fwd_pd, *this->data_, *this->weight_, + *this->bias_, *this->out_); + } else { + fwd_ = std::make_shared(fwd_pd, *this->data_, *this->weight_, + *this->out_); } } +void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory *bias, const mkldnn::memory &output) { + data_->set_data_handle(data.get_data_handle()); + weight_->set_data_handle(weight.get_data_handle()); + out_->set_data_handle(output.get_data_handle()); + if (bias != nullptr) bias_->set_data_handle(bias->get_data_handle()); +} + MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h index 89c3c199488a..3c65172c6116 100644 --- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h @@ -74,6 +74,9 @@ static void MKLDNNDequantizeComputeKer(const std::vector &inputs, i_dims[i] = static_cast(in_buffer.shape()[i]); } mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); + if (i_fmt == mkldnn::memory::format::nhwc) { + i_fmt = mkldnn::memory::format::nchw; + } auto o_desc = mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum::type, i_fmt); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h new file mode 100644 index 000000000000..8e115fab7170 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file mkldnn_quantize_v2-inl.h + * \brief + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_ +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include +#include "../../nn/mkldnn/mkldnn_base-inl.h" +#include "../quantize_v2-inl.h" + +namespace mxnet { +namespace op { + +template +static void MKLDNNQuantizeComputeKer(const std::vector& inputs, + const std::vector& outputs, + const QuantizeV2Param& param, + const std::vector& req) { + using namespace mshadow; + using namespace mxnet_op; + using red::limits::MaxValue; + using red::limits::MinValue; + float real_range = 0.0; + float quantized_range = 0.0; + NDArray in_buffer = inputs[0]; + float data_min = red::limits::MaxValue(); + float data_max = red::limits::MinValue(); + + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + data_min = param.min_calib_range.value(); + data_max = param.max_calib_range.value(); + } else { + // no calib info + in_buffer = inputs[0].Reorder2Default(); + auto in_ptr = in_buffer.data().dptr(); +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) \ + reduction(min : data_min) reduction(max : data_max) + for (int64_t i = 0; i < in_buffer.shape().Size(); i++) { + if (in_ptr[i] > data_max) data_max = in_ptr[i]; + if (in_ptr[i] < data_min) data_min = in_ptr[i]; + } + } + auto out_type = GetOutputType(param); + if (out_type == mshadow::kUint8) { + real_range = MaxAbs(data_min, data_max); + quantized_range = MaxAbs(MaxValue(), MinValue()); + *outputs[1].data().dptr() = data_min; + *outputs[2].data().dptr() = data_max; + } else if (out_type == mshadow::kInt8) { + real_range = MaxAbs(data_min, data_max); + quantized_range = MinAbs(MaxValue(), MinValue()); + *outputs[1].data().dptr() = -real_range; + *outputs[2].data().dptr() = real_range; + } else { + LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; + } + float scale = quantized_range / real_range; + + primitive_attr attr; + const int mask = 0; + std::vector scales = {scale}; + attr.set_output_scales(mask, scales); + attr.set_int_output_round_mode(round_nearest); + mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); + + if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default(); + auto i_mem = in_buffer.GetMKLDNNData(); + auto i_mpd = i_mem->get_primitive_desc(); + auto i_desc = i_mpd.desc(); + mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); + if (i_fmt == mkldnn::memory::format::nchw || + i_fmt == mkldnn::memory::format::nChw8c || + i_fmt == mkldnn_nChw16c) { + i_fmt = mkldnn::memory::format::nhwc; + } + size_t i_ndim = in_buffer.shape().ndim(); + mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); + for (size_t i = 0; i < i_ndim; i++) { + i_dims[i] = static_cast(in_buffer.shape()[i]); + } + auto o_desc = + mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum::type, i_fmt); + auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); + auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); + auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]); + MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second)); + CommitOutput(outputs[0], o_mem); + MKLDNNStream::Get()->Submit(); +} + +static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + 
const QuantizeV2Param& param = nnvm::get(attrs.parsed); + auto out_type = GetOutputType(param); + if (out_type == mshadow::kUint8) { + MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + } else if (out_type == mshadow::kInt8) { + MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + } else { + LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_ diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h index ee7112205892..efc841009706 100644 --- a/src/operator/quantization/quantization_utils.h +++ b/src/operator/quantization/quantization_utils.h @@ -27,6 +27,7 @@ #include #include #include "../mxnet_op.h" +#include "../tensor/broadcast_reduce_op.h" namespace mxnet { namespace op { @@ -171,6 +172,20 @@ struct QuantizationRangeForMultiplicationStruct { } }; +template +inline size_t ConfigReduce(mshadow::Stream* s, + const TShape& data_shape, + const TShape& out_shape, + TShape* src_shape, + TShape* dst_shape) { + BroadcastReduceShapeCompact(data_shape, out_shape, src_shape, dst_shape); + constexpr int NDim = 2; + CHECK_EQ(src_shape->ndim(), NDim); + CHECK_EQ(dst_shape->ndim(), NDim); + + return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape); +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZATION_UTILS_H_ diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index fcd0fb4218be..8706087f9879 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -26,6 +26,7 @@ #include #include #include +#include "quantize_v2-inl.h" namespace mxnet { namespace op { @@ -63,12 +64,12 @@ NodePtr InsertNode(std::string op_name, } std::vector OfflineParams(std::vector&& outputs, - std::unordered_set&& offline_params) { + const std::unordered_set&& offline_params) { std::string node_suffixs[3] = {"", "_min", "_max"}; std::unordered_map mirror_map; nnvm::NodeEntryMap entry_var; auto need_offline = [&](NodePtr n) { - return (n->op() == Op::Get("_contrib_quantize")) && + return (n->op() == Op::Get("_contrib_quantize_v2")) && n->inputs[0].node->is_variable() && offline_params.count(n->inputs[0].node->attrs.name); }; @@ -121,10 +122,9 @@ Graph QuantizeGraph(Graph &&src) { static const auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); - auto offline_params = src.GetAttr>("offline_params"); - auto excluded_nodes = src.GetAttr>("excluded_nodes"); - auto quantized_dtype = src.GetAttr("quantized_dtype"); - auto calib_quantize = src.GetAttr("calib_quantize"); + const auto offline_params = src.GetAttr>("offline_params"); + const auto excluded_nodes = src.GetAttr>("excluded_nodes"); + const auto quantized_dtype = src.GetAttr("quantized_dtype"); // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. 
Key is the currently visited graph's node pointer, and value is a copied node of the key @@ -174,24 +174,10 @@ Graph QuantizeGraph(Graph &&src) { } } - NodePtr quantize_node = InsertNode("_contrib_quantize", + NodePtr quantize_node = InsertNode("_contrib_quantize_v2", e.node->attrs.name + suffix + "_quantize", new_node, mirror_entry); quantize_node->attrs.dict["out_type"] = quantized_dtype; quantize_node->op()->attr_parser(&(quantize_node->attrs)); - if (calib_quantize) { - NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + suffix + "_min"); - quantize_node->inputs.emplace_back(NodeEntry{min_var, 0, 0}); - NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + suffix + "_max"); - quantize_node->inputs.emplace_back(NodeEntry{max_var, 0, 0}); - } else { - NodePtr min_node = InsertNode("min", - e.node->attrs.name + suffix + "_min", quantize_node, mirror_entry); - min_node->op()->attr_parser(&(min_node->attrs)); - - NodePtr max_node = InsertNode("max", - e.node->attrs.name + suffix + "_max", quantize_node, mirror_entry); - max_node->op()->attr_parser(&(max_node->attrs)); - } mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version}; } } else if (mirror_node->op() == Op::Get("_contrib_dequantize")) { @@ -269,43 +255,35 @@ Graph QuantizeGraph(Graph &&src) { // the new_node. *new_node = *node; new_node->inputs.clear(); - if (node->is_variable() && node->attrs.name == "data") { - // Insert identity for data to collect calib for it. - NodePtr identity_node = - CreateNode("identity", new_node->attrs.name + "_id"); - identity_node->inputs.emplace_back(NodeEntry{new_node, 0, 0}); - new_node = identity_node; - } else { - for (const auto& e : node->inputs) { - NodePtr mirror_node = mirror_map.at(e.node.get()); - NodeEntry mirror_entry = NodeEntry{ - mirror_node, e.index, e.version}; - // if input node is quantized operator, add dequantize node - if (NeedQuantize(e.node, excluded_nodes) && - (mirror_node->op() != Op::Get("_contrib_dequantize"))) { - // here we calculate the output number (exclude min/max, in order to - // calculate min/max index from mirror node) based on assumption that - // there is only 1min and 1max output from mirror node (which is - // currently true) - size_t num_outputs = mirror_node->num_outputs() - 2; - uint32_t min_index = num_outputs + 2 * e.index; - uint32_t max_index = num_outputs + 2 * e.index + 1; - NodePtr dequantize_node = CreateNode("_contrib_dequantize", - e.node->attrs.name + "_dequantize"); - dequantize_node->inputs.emplace_back(mirror_entry); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); - dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); + for (const auto& e : node->inputs) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; + // if input node is quantized operator, add dequantize node + if (NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() != Op::Get("_contrib_dequantize"))) { + // here we calculate the output number (exclude min/max, in order to + // calculate min/max index from mirror node) based on assumption that + // there is only 1min and 1max output from mirror node (which is + // currently true) + size_t num_outputs = mirror_node->num_outputs() - 2; + uint32_t min_index = num_outputs + 2 * e.index; + uint32_t max_index = num_outputs + 2 * e.index + 1; + NodePtr dequantize_node = CreateNode("_contrib_dequantize", + 
e.node->attrs.name + "_dequantize"); + dequantize_node->inputs.emplace_back(mirror_entry); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); - new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); - mirror_map[e.node.get()] = std::move(dequantize_node); - } else if (mirror_entry_map.count(e)) { - new_node->inputs.emplace_back( - NodeEntry{mirror_entry_map[e].node->inputs[0].node, e.index, e.version}); - } else { - new_node->inputs.emplace_back( - NodeEntry{mirror_node, e.index, e.version}); - } + new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + mirror_map[e.node.get()] = std::move(dequantize_node); + } else if (mirror_entry_map.count(e)) { + new_node->inputs.emplace_back( + NodeEntry{mirror_entry_map[e].node->inputs[0].node, e.index, e.version}); + } else { + new_node->inputs.emplace_back( + NodeEntry{mirror_node, e.index, e.version}); } } } @@ -361,7 +339,11 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) { && need_requantize_map[quantized_op_node->op()](quantized_op_node->attrs)) << quantized_op_node->attrs.name << " op must register FNeedRequantize attr" " and the attr func should return true"; - std::string out_data_name = quantized_op_node->attrs.name + "_"; + const std::string prefix = "quantized_"; + CHECK(std::equal(prefix.begin(), prefix.end(), quantized_op_node->attrs.name.begin())) + << "an quantized op should start with `quantized_`"; + + std::string out_data_name = quantized_op_node->attrs.name.substr(prefix.size()) + "_"; auto list_output_names_func = flist_outputs.get(quantized_op_node->op(), nullptr); // Here it's assumed that the quantized_op node only produces three outputs: // out_data, min_range, and max_range. So we want to get the pre-calculated min_calib_range @@ -381,6 +363,34 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) { node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); node->op()->attr_parser(&(node->attrs)); } + } else if (node->op() == Op::Get("_contrib_quantize_v2")) { + NodePtr float_op_node = node->inputs[0].node; + auto float_op_idx = node->inputs[0].index; + std::string out_data_name = float_op_node->attrs.name; + if (float_op_node->op()) { + auto list_output_names_func = flist_outputs.get(float_op_node->op(), nullptr); + // We want to get the pre-calculated min_range and max_range from the calibration table for + // out_data. Here we create the output data name same as its constructed in + // GraphExecutor::ExecuteMonCallback. 
+ if (list_output_names_func != nullptr) { + std::vector names = list_output_names_func(float_op_node->attrs); + out_data_name += "_" + names[float_op_idx]; + } else { + out_data_name += "_" + std::to_string(float_op_idx); + } + } + const auto calib_table_iter = calib_table.find(out_data_name); + if (calib_table_iter != calib_table.end()) { + node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); + node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); + node->op()->attr_parser(&(node->attrs)); + const QuantizeV2Param& param = nnvm::get(node->attrs.parsed); + if (param.out_type == QuantizeV2Param::OutType::kUint8 && + param.min_calib_range.value() < 0.0f) { + LOG(WARNING) << "Calibration statistics indicates that node `" << node->attrs.name + << "` has negative input, consider use `auto` or `int8` as out_type"; + } + } } }); return g; diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h new file mode 100644 index 000000000000..9ba2c3a5437c --- /dev/null +++ b/src/operator/quantization/quantize_v2-inl.h @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantize_v2-inl.h + * \brief implementation of quantize operation + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_ + +#include +#include +#include +#include "../elemwise_op_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "./quantization_utils.h" +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct QuantizeV2Param : public dmlc::Parameter { + enum OutType { kAuto = 0, kInt8, kUint8 }; + int out_type; + dmlc::optional min_calib_range; + dmlc::optional max_calib_range; + DMLC_DECLARE_PARAMETER(QuantizeV2Param) { + DMLC_DECLARE_FIELD(out_type) + .add_enum("auto", kAuto) + .add_enum("int8", kInt8) + .add_enum("uint8", kUint8) + .set_default(kUint8) + .describe("Output data type. `auto` can be specified to automatically determine output type " + "according to min_calib_range."); + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32. If present, it will be used to " + "quantize the fp32 data into int8 or uint8."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32. 
If present, it will be used to " + "quantize the fp32 data into int8 or uint8."); + } +}; + +static mshadow::TypeFlag GetOutputType(const QuantizeV2Param ¶m) { + auto out_type = mshadow::kInt8; + if (param.out_type == QuantizeV2Param::OutType::kAuto) { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + if (param.min_calib_range.value() >= 0.0) { + out_type = mshadow::kUint8; + } else { + out_type = mshadow::kInt8; + } + } + } else if (param.out_type == QuantizeV2Param::OutType::kInt8) { + out_type = mshadow::kInt8; + } else if (param.out_type == QuantizeV2Param::OutType::kUint8) { + out_type = mshadow::kUint8; + } else { + LOG(FATAL) << "Unsupported quantize output type."; + } + return out_type; +} + +// quantize float to uint8_t +struct quantize_v2_unsigned { + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, + const SrcDType *in, const float *imin_range, + const float *imax_range, const double min_limit, + const double max_limit) { + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + const float scale = (max_limit - min_limit) / (*imax_range - *imin_range); + out[i] = static_cast((in[i] - *imin_range) * scale + 0.5); + *omin_range = *imin_range; + *omax_range = *imax_range; + } +}; + +// keep zero-center +struct quantize_v2_zero_centered { + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, + const SrcDType *in, const float *imin_range, + const float *imax_range, const float quantized_range) { + float real_range = MaxAbs(*imin_range, *imax_range); + float scale = quantized_range / real_range; + SrcDType x = in[i]; + out[i] = static_cast(Sign(x) * Min(Abs(x) * scale + 0.5f, quantized_range)); + *omin_range = -real_range; + *omax_range = real_range; + } +}; + +template +void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + typedef float SrcDType; + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + Stream *s = ctx.get_stream(); + + const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); + auto out_type = GetOutputType(param); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + auto in_min = param.min_calib_range.value(); + auto in_max = param.max_calib_range.value(); + if (out_type == mshadow::kUint8) { + Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), + outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), &in_min, &in_max, + MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), &in_min, &in_max, + MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } + } else { // model is not calibrated + TShape src_shape, dst_shape; + const size_t actual_float_size = sizeof(float); + const size_t actual_quantized_size = sizeof(SrcDType); + const size_t temp_reduce_size = + ConfigReduce(s, inputs[0].shape_, TShape({1}), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * actual_float_size + 2 * actual_quantized_size + temp_reduce_size), s); + Tensor actual_min_float(reinterpret_cast(temp_space.dptr_), Shape1(1), + s); + Tensor 
actual_max_float(reinterpret_cast(temp_space.dptr_) + 1, + Shape1(1), s); + + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob actual_min_quantized(reinterpret_cast(temp_space.dptr_ + 8), Shape1(1), + xpu::kDevMask, dev_id); + TBlob actual_max_quantized(reinterpret_cast(temp_space.dptr_ + 8) + 1, Shape1(1), + xpu::kDevMask, dev_id); + Tensor workspace( + temp_space.dptr_ + 2 * actual_float_size + 2 * actual_quantized_size, + Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, actual_min_quantized.reshape(dst_shape), kWriteTo, workspace, + inputs[0].reshape(src_shape)); + Kernel::Launch(s, 1, actual_min_float.dptr_, + actual_min_quantized.dptr(), + inputs[1].dptr(), inputs[2].dptr()); + + broadcast::Reduce( + s, actual_max_quantized.reshape(dst_shape), kWriteTo, workspace, + inputs[0].reshape(src_shape)); + Kernel::Launch(s, 1, actual_max_float.dptr_, + actual_max_quantized.dptr(), + inputs[1].dptr(), inputs[2].dptr()); + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), actual_min_float.dptr_, + actual_max_float.dptr_, MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), actual_min_float.dptr_, + actual_max_float.dptr_, MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } + } +} + +static inline bool QuantizeV2Shape(const nnvm::NodeAttrs &attrs, std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1}); + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1}); + return !shape_is_none(out_attrs->at(0)); +} + +static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); + TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32); + auto out_type = GetOutputType(param); + if (out_type == mshadow::kUint8) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); + } else if (out_type == mshadow::kInt8) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8); + } else { + LOG(FATAL) << "Unsupported out_type."; + } + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); + return (*in_attrs)[0] != -1; +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_ diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc new file mode 100644 index 000000000000..afa341f8a780 --- /dev/null +++ b/src/operator/quantization/quantize_v2.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantize.cc + * \brief + */ +#include "./quantize_v2-inl.h" +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_quantize_v2-inl.h" +#endif + +namespace mxnet { +namespace op { +DMLC_REGISTER_PARAMETER(QuantizeV2Param); + +static bool QuantizeV2StorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + (*out_attrs)[0] = kDefaultStorage; + (*out_attrs)[1] = kDefaultStorage; + (*out_attrs)[2] = kDefaultStorage; + return true; +} + +NNVM_REGISTER_OP(_contrib_quantize_v2) +.describe(R"code(Quantize a input tensor from float to `out_type`, +with user-specified `min_calib_range` and `max_calib_range` or the input range collected at runtime. + +Output `min_range` and `max_range` are scalar floats that specify the range for the input data. + +When out_type is `uint8`, the output is calculated using the following equation: + +`out[i] = (in[i] - min_range) * range(OUTPUT_TYPE) / (max_range - min_range) + 0.5`, + +where `range(T) = numeric_limits::max() - numeric_limits::min()`. + +When out_type is `int8`, the output is calculate using the following equation +by keep zero centered for the quantized value: + +`out[i] = sign(in[i]) * min(abs(in[i] * scale + 0.5f, quantized_range)`, + +where +`quantized_range = MinAbs(max(int8), min(int8))` and +`scale = quantized_range / MaxAbs(min_range, max_range).` + +When out_type is `auto`, the output type is automatically determined by min_calib_range if presented. +If min_calib_range < 0.0f, the output type will be int8, otherwise will be uint8. +If min_calib_range isn't presented, the output type will be int8. + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(3) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", QuantizeV2Shape) +.set_attr("FInferType", QuantizeV2Type) +.set_attr("FInferStorageType", QuantizeV2StorageType) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", MKLDNNQuantizeV2Compute) +#endif +.set_attr("FCompute", QuantizeV2Compute) +.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") +.add_arguments(QuantizeV2Param::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantize_v2.cu b/src/operator/quantization/quantize_v2.cu new file mode 100644 index 000000000000..ab0cf9c5ad0e --- /dev/null +++ b/src/operator/quantization/quantize_v2.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file quantize_v2.cu + * \brief + */ +#include "./quantize_v2-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_quantize_v2) +.set_attr("FCompute", QuantizeV2Compute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h index e07a149f8a6b..148453e63257 100644 --- a/src/operator/quantization/requantize-inl.h +++ b/src/operator/quantization/requantize-inl.h @@ -87,20 +87,6 @@ struct RequantizeKernel { } }; -template -inline size_t ConfigReduce(mshadow::Stream* s, - const TShape& data_shape, - const TShape& out_shape, - TShape* src_shape, - TShape* dst_shape) { - BroadcastReduceShapeCompact(data_shape, out_shape, src_shape, dst_shape); - constexpr int NDim = 2; - CHECK_EQ(src_shape->ndim(), NDim); - CHECK_EQ(dst_shape->ndim(), NDim); - - return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape); -} - template void RequantizeForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index dfa98d1f5ee9..2099d1b1ec24 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -43,10 +43,10 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, true, beta.dtype()); const DType *weight_ptr = weight->data().dptr(); const DType *bias_ptr = no_bias ? 
nullptr : bias->data().dptr(); - const DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); - const DType *beta_ptr = beta.Reorder2Default().data().dptr(); - const DType *mean_ptr = mean.Reorder2Default().data().dptr(); - const DType *var_ptr = variance.Reorder2Default().data().dptr(); + const DType *gamma_ptr = gamma.data().dptr(); + const DType *beta_ptr = beta.data().dptr(); + const DType *mean_ptr = mean.data().dptr(); + const DType *var_ptr = variance.data().dptr(); DType *update_weight_ptr = update_weight.data().dptr(); DType *update_bias_ptr = update_bias.data().dptr(); size_t channel = gamma.shape()[0]; @@ -77,23 +77,17 @@ static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { } template -static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, - bool has_bias, float data_scale, - bool weight_channelwise_scale, - std::vector *weight_scales) { +static std::vector GetWeightScales(const NDArray &weight, bool weight_channelwise_scale) { using red::limits::MaxValue; using red::limits::MinValue; - const DType *weight_ptr = weight->data().dptr(); - NDArray quantized_weight = NDArray(weight->storage_type(), weight->shape(), - weight->ctx(), true, mshadow::kInt8); - int8_t *quan_weight_ptr = quantized_weight.data().dptr(); - size_t channel = weight->shape()[0]; + std::vector weight_scales; + const DType *weight_ptr = weight.data().dptr(); + size_t channel = weight.shape()[0]; // TODO(Zhennan): Handle the case weight is not in dims 4. - size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; + size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; std::vector weight_c_min(channel, MaxValue()); std::vector weight_c_max(channel, MinValue()); -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { const DType *p1 = weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { @@ -105,16 +99,10 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, } if (weight_channelwise_scale) { - weight_scales->resize(channel); -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + weight_scales.resize(channel); for (int c = 0; c < static_cast(channel); ++c) { DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); - weight_scales->at(c) = kInt8Range / weight_range; - const DType *fp_ptr = weight_ptr + c * offset; - int8_t *quan_ptr = quan_weight_ptr + c * offset; - for (size_t k = 0; k < offset; ++k) { - quan_ptr[k] = std::round(weight_scales->at(c) * fp_ptr[k]); - } + weight_scales[c] = kInt8Range / weight_range; } } else { DType total_min = weight_c_min[0]; @@ -123,74 +111,73 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; } - weight_scales->resize(1); + weight_scales.resize(1); DType weight_range = MaxAbs(total_min, total_max); - weight_scales->at(0) = kInt8Range / weight_range; -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int c = 0; c < static_cast(channel); ++c) { - const DType *fp_ptr = weight_ptr + c * offset; - int8_t *quan_ptr = quan_weight_ptr + c * offset; - for (size_t k = 0; k < offset; ++k) { - quan_ptr[k] = std::round(weight_scales->at(0) * fp_ptr[k]); - } - } - } - - *weight = quantized_weight; - if (has_bias) { - const DType *bias_ptr = bias->data().dptr(); - NDArray 
quantized_bias = NDArray(bias->storage_type(), bias->shape(), - bias->ctx(), true, mshadow::kInt32); - int32_t *quan_bias_ptr = quantized_bias.data().dptr(); - for (size_t c = 0; c < channel; ++c) { - auto weight_scale = - weight_channelwise_scale ? weight_scales->at(c) : weight_scales->at(0); - float bias_scale = weight_scale * data_scale; - quan_bias_ptr[c] = std::round(bias_scale * bias_ptr[c]); - } - *bias = quantized_bias; + weight_scales[0] = kInt8Range / weight_range; } + return weight_scales; } -static void ConvFusionFallBackCompute() { - LOG(FATAL) << "Don't know how to do ConvFusionFallBackCompute!"; -} - -static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param, - const OpContext &ctx, - MKLDNNConvForward *fwd, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - if (SupportMKLDNNConv(full_param.conv_param, inputs[0])) { - MKLDNNConvolutionForwardFullFeature(full_param, ctx, fwd, inputs, req, outputs); - return; +static void ConvertWeightBias2MKLDNN(const MKLDNNConvFullParam ¶m, + mkldnn::convolution_forward::primitive_desc fwd_pd, + NDArray *weight, NDArray *bias, bool has_bias, + float data_scale, const std::vector &weight_scales) { + MKLDNNStream *stream = MKLDNNStream::Get(); + const auto new_weight = NDArray(fwd_pd.weights_primitive_desc()); + const auto conv_weights_memory = new_weight.GetMKLDNNData(); + primitive_attr weight_attr; + if (weight_scales.size()) { + const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1; + weight_attr.set_int_output_round_mode(round_mode::round_nearest); + weight_attr.set_output_scales(weight_mask, weight_scales); + } + auto default_weights_memory = GetWeights(*weight, param.conv_param.num_group); + if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData(); + const auto weight_reorder_pd = + mkldnn::reorder::primitive_desc(default_weights_memory->get_primitive_desc(), + conv_weights_memory->get_primitive_desc(), weight_attr); + stream->RegisterPrim( + mkldnn::reorder(weight_reorder_pd, *default_weights_memory, *conv_weights_memory)); + + NDArray new_bias; + if (has_bias && data_scale) { + std::vector bias_scales(weight_scales.size()); + for (size_t c = 0; c < weight_scales.size(); ++c) { + bias_scales[c] = weight_scales[c] * data_scale; + } + new_bias = NDArray(fwd_pd.bias_primitive_desc()); + const auto conv_bias_memory = new_bias.GetMKLDNNData(); + const int bias_mask = (bias_scales.size()) == 1 ? 
0 : 1; + primitive_attr bias_attr; + bias_attr.set_int_output_round_mode(round_mode::round_nearest); + bias_attr.set_output_scales(bias_mask, bias_scales); + auto bias_weights_memory = bias->GetMKLDNNData(); + auto bias_reorder_pd = + mkldnn::reorder::primitive_desc(bias_weights_memory->get_primitive_desc(), + conv_bias_memory->get_primitive_desc(), bias_attr); + stream->RegisterPrim( + mkldnn::reorder(bias_reorder_pd, *bias_weights_memory, *conv_bias_memory)); } - ConvFusionFallBackCompute(); + stream->Submit(); + *weight = new_weight; + if (has_bias && data_scale) *bias = new_bias; } class SgMKLDNNConvOperator { public: explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) - : initalized_(false), - subgraph_sym_(*attrs.subgraphs[0]), - param_(nnvm::get(attrs.parsed)), - inplace_(false) {} + : subgraph_sym_(*attrs.subgraphs[0]), + param_(nnvm::get(attrs.parsed)) {} void Forward(const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs); - void Backward(const OpContext &ctx, const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports " - "inference computation."; - } - private: - bool initalized_; + bool initalized_{false}; + bool inplace_{false}; + bool post_requantize_{false}; nnvm::Symbol subgraph_sym_; MKLDNNConvFusionParam param_; std::shared_ptr fwd_; @@ -200,10 +187,12 @@ class SgMKLDNNConvOperator { float cached_data_max_; float cached_sum_min_; float cached_sum_max_; + float cached_output_min_; + float cached_output_max_; size_t weight_ver_; size_t bias_ver_; + float data_scale_{0.0f}; std::vector weight_scales_; - bool inplace_; }; void SgMKLDNNConvOperator::Forward(const OpContext &ctx, @@ -239,10 +228,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized) ? inputs[idx++].data().dptr()[0] : 0.0; - float *out_min_ptr = - mkldnn_param.quantized ? outputs[kMin].data().dptr() : nullptr; - float *out_max_ptr = - mkldnn_param.quantized ? outputs[kMax].data().dptr() : nullptr; CHECK_EQ(input_size, idx); bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias; NDArray data = inputs[in_data]; @@ -251,18 +236,22 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, // Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed. if (mkldnn_param.with_sum) { if (!initalized_) { - auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); - auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option, // which make check (req[kOut] == kWriteInplace) useless. 
+ auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); + auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) { inplace_ = true; } } if (!inplace_) { auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); - const_cast(outputs[kOut]).CopyFrom(*in_mkl_mem); - output = NDArray(outputs[kOut].GetMKLDNNData()); + auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); + mkldnn_mem_ptr tmp_mem( + new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle())); + MKLDNNStream::Get()->RegisterMem(tmp_mem); + mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get()); + output = NDArray(tmp_mem); } } @@ -284,19 +273,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } } } - bool post_requantize = false; - if (mkldnn_param.quantized) { - if (mkldnn_param.min_calib_range.has_value() && - mkldnn_param.max_calib_range.has_value()) { - post_requantize = true; - mkldnn_param.weight_channelwise_scale = true; - *out_min_ptr = mkldnn_param.min_calib_range.value(); - *out_max_ptr = mkldnn_param.max_calib_range.value(); - } else { - mkldnn_param.weight_channelwise_scale = false; - } - } - if (!initalized_) { cached_data_min_ = data_min; cached_data_max_ = data_max; @@ -306,7 +282,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, cached_weight_ = inputs[in_weight].Reorder2Default(); weight_ver_ = inputs[in_weight].version(); if (!conv_param.no_bias) { - cached_bias_ = inputs[in_bias].Reorder2Default(); + cached_bias_ = inputs[in_bias]; bias_ver_ = inputs[in_bias].version(); } else { cached_bias_ = NDArray(); @@ -327,13 +303,23 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, // Quantize weight and bias. if (mkldnn_param.quantized) { CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8); + if (cached_data_min_ < 0.0f) { + CHECK_EQ(data.dtype(), mshadow::kInt8) + << "Expect int8 when data_min < 0.0, consider quantize model with int8."; + } + if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) { + cached_output_min_ = mkldnn_param.min_calib_range.value(); + cached_output_max_ = mkldnn_param.max_calib_range.value(); + post_requantize_ = true; + mkldnn_param.weight_channelwise_scale = true; + } else { + mkldnn_param.weight_channelwise_scale = false; + } auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range; - float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_); + data_scale_ = data_range / MaxAbs(cached_data_min_, cached_data_max_); MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { - QuantizeConvWeightBias(&cached_weight_, &cached_bias_, - has_bias, data_scale, - mkldnn_param.weight_channelwise_scale, - &weight_scales_); + weight_scales_ = + GetWeightScales(cached_weight_, mkldnn_param.weight_channelwise_scale); }); // Collect scale. size_t channel = cached_weight_.shape()[0]; @@ -341,29 +327,21 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, float out_range; float quantized_out_range; float output_scale; - if (cached_data_min_ < 0.0) { - // TODO(zhennan): Support int8 input when mkldnn supports. - LOG(FATAL) << "Can't handle negetive value for QuantizeData"; - } if (mkldnn_param.with_sum) { auto quantized_sum_range = cached_sum_min_ < 0 ? kInt8Range : kUint8Range; sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_); } - if (post_requantize) { - quantized_out_range = - IsOutputUInt8(mkldnn_param) ? 
kUint8Range : kInt8Range; - out_range = MaxAbs(*out_min_ptr, *out_max_ptr); + if (post_requantize_) { + quantized_out_range = IsOutputUInt8(mkldnn_param) ? kUint8Range : kInt8Range; + out_range = MaxAbs(cached_output_min_, cached_output_max_); output_scale = quantized_out_range / out_range; - full_conv_param.requantize_scales.resize(channel); - for (size_t c = 0; c < channel; c++) { - auto weight_scale = mkldnn_param.weight_channelwise_scale - ? weight_scales_[c] - : weight_scales_[0]; - full_conv_param.requantize_scales[c] = - output_scale / data_scale / weight_scale; + full_conv_param.requantize_scales.resize(mkldnn_param.weight_channelwise_scale ? channel + : 1); + for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) { + full_conv_param.requantize_scales[c] = output_scale / data_scale_ / weight_scales_[c]; } } else { - output_scale = data_scale * weight_scales_[0]; + output_scale = data_scale_ * weight_scales_[0]; full_conv_param.requantize_scales.resize(0); } if (mkldnn_param.with_sum) @@ -372,23 +350,44 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, fwd_.reset(new MKLDNNConvForward( full_conv_param, ctx.is_train, data, cached_weight_, has_bias ? &cached_bias_ : nullptr, output)); + ConvertWeightBias2MKLDNN(full_conv_param, fwd_->fwd_pd, &cached_weight_, &cached_bias_, + has_bias, data_scale_, weight_scales_); + fwd_->SetNewMem(*data.GetMKLDNNData(), *cached_weight_.GetMKLDNNData(), + has_bias ? cached_bias_.GetMKLDNNData() : nullptr, + *output.GetMKLDNNData()); + initalized_ = true; } - initalized_ = true; - std::vector new_inputs; - std::vector new_req; - if (has_bias) { - new_inputs = {data, cached_weight_, cached_bias_}; - new_req = {req[in_data], req[in_weight], req[in_bias]}; + + if (!mkldnn_param.quantized) { + auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc()); + mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc()); + fwd_->SetNewMem(*data_mem, *mem); + MKLDNNStream::Get()->RegisterPrim(fwd_->GetFwd()); + MKLDNNStream::Get()->Submit(); } else { - new_inputs = {data, cached_weight_}; - new_req = {req[in_data], req[in_weight]}; + std::vector new_inputs; + std::vector new_req; + if (has_bias) { + new_inputs = {data, cached_weight_, cached_bias_}; + new_req = {req[in_data], req[in_weight], req[in_bias]}; + } else { + new_inputs = {data, cached_weight_}; + new_req = {req[in_data], req[in_weight]}; + } + MKLDNNConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), new_inputs, new_req, + {output}); + } + if (post_requantize_) { + float *out_min_ptr = outputs[kMin].data().dptr(); + float *out_max_ptr = outputs[kMax].data().dptr(); + *out_min_ptr = cached_output_min_; + *out_max_ptr = cached_output_max_; } - ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs, - new_req, {output}); - if (mkldnn_param.with_sum) { auto out = const_cast(outputs[kOut]); - out.UpdateMKLDNNMemDesc(); + auto format = static_cast( + fwd_->fwd_pd.dst_primitive_desc().desc().data.format); + out.UpdateMKLDNNMemDesc(format); } } @@ -405,7 +404,7 @@ static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs &attrs) { auto const ¶m = nnvm::get(attrs.parsed); auto num_input = DefaultSubgraphOpNumInputs(attrs); if (param.full_conv_param.mkldnn_param.quantized) - return num_input + 2 + param.full_conv_param.mkldnn_param.with_sum ? 2 : 0; + return num_input + 2 + (param.full_conv_param.mkldnn_param.with_sum ? 
2 : 0); else return num_input; } @@ -425,6 +424,7 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { os << ")"; throw dmlc::ParamError(os.str()); } + CHECK_EQ(attrs->subgraphs.size(), 1); auto subgraph_sym = attrs->subgraphs[0]; DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) { if (node->is_variable()) return; @@ -442,10 +442,23 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { attrs->parsed = std::move(param_); } -static std::vector SgMKLDNNConvListInputNames( - const NodeAttrs &attrs) { +static std::vector SgMKLDNNConvListInputNames(const NodeAttrs &attrs) { auto const ¶m = nnvm::get(attrs.parsed); - std::vector input_names = DefaultSubgraphOpListInputs(attrs); + std::vector input_names; + input_names.emplace_back("data"); + input_names.emplace_back("weight"); + if (!param.full_conv_param.conv_param.no_bias) { + input_names.emplace_back("bias"); + } + if (param.full_conv_param.mkldnn_param.with_bn) { + input_names.emplace_back("gamma"); + input_names.emplace_back("beta"); + input_names.emplace_back("mean"); + input_names.emplace_back("var"); + } + if (param.full_conv_param.mkldnn_param.with_sum) { + input_names.emplace_back("sum"); + } if (param.full_conv_param.mkldnn_param.quantized) { input_names.emplace_back("data_min"); input_names.emplace_back("data_max"); @@ -454,6 +467,7 @@ static std::vector SgMKLDNNConvListInputNames( input_names.emplace_back("sum_max"); } } + CHECK_EQ(input_names.size(), SgMKLDNNConvNumInputs(attrs)); return input_names; } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index e5220f24d34d..adfc41bb120f 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -66,17 +66,21 @@ class SgMKLDNNConvSelector : public SubgraphSelector { } bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { - if (status == kFail || status == kSuccess || new_node.is_variable()) - return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. if (matched_list.back() != &n) { - while (matched_list.back() != &n) { - matched_list.pop_back(); + if (std::find(matched_list.begin(), matched_list.end(), &n) != + matched_list.end()) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } } status = kSuccess; return false; } + if (status == kFail || status == kSuccess || new_node.is_variable()) + return false; + // Use status machine to do selection. The status change is // kStart -> kBN -> kSum -> kSuccess switch (status) { @@ -99,12 +103,11 @@ class SgMKLDNNConvSelector : public SubgraphSelector { nnvm::get(new_node.attrs.parsed); if (param.act_type == activation::kReLU) { matched_list.push_back(&new_node); - // If we find conv+relu, then we can't match bn anymore. - if (status == kStart) status = kBN; - return true; - } else { + // If we find conv+relu, then we can't match anymore. + // TODO(zhennan): mkldnn only supports convolution + relu + sum in + // int8, not in fp32. So we disable this pattern at moment. 
status = kSuccess; - return false; + return true; } } status = kSuccess; @@ -117,7 +120,15 @@ class SgMKLDNNConvSelector : public SubgraphSelector { if (status == kFail) { return std::vector(0); } else { - return candidates; + std::vector ret; + for (auto i : matched_list) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != + candidates.end()) { + ret.push_back(non_const_i); + } + } + return ret; } } }; diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index be6feaeb94a6..313668cb56f9 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -35,14 +35,14 @@ DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] -def check_qsym_calibrated(qsym): +def check_qsym_calibrated(qsym, out_type): assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 for k, v in qsym.attr_dict().items(): if k.find('quantized_sg_mkldnn_conv') != -1: assert 'min_calib_range' in v assert 'max_calib_range' in v if k.find('_quantize') != -1: - assert v['out_type'] == 'uint8' + assert v['out_type'] == out_type def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape): mod = mx.mod.Module(symbol=qsym, context=mx.current_context()) @@ -66,7 +66,7 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape): output.wait_to_read() return mod.get_outputs() -def check_quantize(sym, data_shape, check_conv=True): +def check_quantize(sym, data_shape, out_type, check_conv=True): fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc') sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') sym_sg = sym.get_backend_symbol("MKLDNN") @@ -99,15 +99,14 @@ def check_quantize(sym, data_shape, check_conv=True): aux_params=aux_params, ctx=mx.current_context(), excluded_sym_names=excluded_sym_names, - quantized_dtype='uint8', + quantized_dtype=out_type, calib_mode='naive', calib_data=calib_data, calib_layer=calib_layer, - calib_quantize_op=True, num_calib_examples=5) qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") if check_conv: - check_qsym_calibrated(qsym) + check_qsym_calibrated(qsym, out_type) quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) for i in range(len(ref_out)): assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1) @@ -135,8 +134,9 @@ def check_fusion(sym, data_shape, attrs_op): for i in range(len(exe.outputs)): assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) - # fp32 to uint8 - check_quantize(sym, data_shape) + # fp32 to int8 + for out_type in ('uint8', 'int8', 'auto'): + check_quantize(sym, data_shape, out_type) def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)): for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): @@ -475,12 +475,13 @@ def test_pos_conv_bn_sum_relu(): def test_pos_single_concat(): for data_shape in DATA_SHAPE: - net = single_concat(data_shape, 2, 1) - check_quantize(net, data_shape, False) - net = single_concat(data_shape, 4, 2) - check_quantize(net, data_shape, False) - net = single_concat(data_shape, 4, 3) - check_quantize(net, data_shape, False) + for out_type in ('uint8', 'int8', 'auto'): + net = single_concat(data_shape, 2, 1) + check_quantize(net, data_shape, out_type, False) + net = single_concat(data_shape, 4, 2) + check_quantize(net, data_shape, out_type, False) + net = single_concat(data_shape, 4, 3) + 
check_quantize(net, data_shape, out_type, False) @with_seed() def test_neg_conv_bn(): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 09157396f839..b25c726cbcc1 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6731,7 +6731,7 @@ def get_output_names_callback(name, arr): output_names.append(py_str(name)) op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') - op_exe.set_monitor_callback(get_output_names_callback) + op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False) op_exe.forward() for output_name, expected_name in zip(output_names, expected_names): assert output_name == expected_name @@ -6769,6 +6769,51 @@ def get_output_names_callback(name, arr): name='pooling') check_name(us_sym, ['pooling_output']) +def test_op_all_names_monitor(): + def check_name(op_sym, expected_names): + output_names = [] + + def get_output_names_callback(name, arr): + output_names.append(py_str(name)) + + op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') + op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True) + op_exe.forward() + for output_name, expected_name in zip(output_names, expected_names): + assert output_name == expected_name + + data = mx.sym.Variable('data', shape=(10, 3, 10, 10)) + conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv') + check_name(conv_sym, ['data', 'conv_data', 'conv_weight', 'conv_weight', 'conv_bias', 'conv_bias', 'conv_output']) + + deconv_sym = mx.sym.Deconvolution(data, kernel=(2, 2), num_filter=1, name='deconv') + check_name(deconv_sym, ['data', 'deconv_data', 'deconv_weight', 'deconv_weight', 'deconv_output']) + + fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc') + check_name(fc_sym, ['data', 'fc_data', 'fc_weight', 'fc_weight', 'fc_bias', 'fc_bias', 'fc_output']) + + lrn_sym = mx.sym.LRN(data, nsize=1, name='lrn') + check_name(lrn_sym, ['data', 'lrn_data', 'lrn_output', 'lrn_tmp_norm']) + + act_sym = mx.sym.Activation(data, act_type='relu', name='act') + check_name(act_sym, ['data', 'act_input0', 'act_output']) + + cc_sym = mx.sym.concat(data, data, dim=0, name='concat') + check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output']) + + sm_sym = mx.sym.softmax(data, name='softmax') + check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output']) + + sa_sym = mx.sym.SoftmaxActivation(data, name='softmax') + check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output']) + + us_sym = mx.sym.UpSampling(data, scale=2, sample_type='nearest', + name='upsampling') + check_name(us_sym, ['data', 'upsampling_arg0', 'upsampling_output']) + + us_sym = mx.sym.Pooling(data, kernel=(2, 2), pool_type='avg', + name='pooling') + check_name(us_sym, ['data', 'pooling_data', 'pooling_output']) @with_seed() def test_activation(): From c642d338b96e526b45732d4d46c5a5c14c976e54 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 20 Dec 2018 12:04:36 +0800 Subject: [PATCH 02/38] Fix cpp build --- cpp-package/include/mxnet-cpp/monitor.h | 3 ++- cpp-package/include/mxnet-cpp/monitor.hpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/monitor.h b/cpp-package/include/mxnet-cpp/monitor.h index c1494d0bd0a6..76e7ce836f18 100644 --- a/cpp-package/include/mxnet-cpp/monitor.h +++ b/cpp-package/include/mxnet-cpp/monitor.h @@ -70,8 +70,9 @@ class Monitor { /*! * \brief install callback to executor. 
Supports installing to multiple executors. * \param exe The executor to install to. + * \param monitor_all If true, monitor both input and output, otherwise monitor output only. */ - void install(Executor *exe); + void install(Executor *exe, bool monitor_all = false); /*! * \brief Start collecting stats for current batch. Call before calling forward. diff --git a/cpp-package/include/mxnet-cpp/monitor.hpp b/cpp-package/include/mxnet-cpp/monitor.hpp index d1e548fbc05a..bd7f1927e906 100644 --- a/cpp-package/include/mxnet-cpp/monitor.hpp +++ b/cpp-package/include/mxnet-cpp/monitor.hpp @@ -43,7 +43,7 @@ inline Monitor::Monitor(int interval, std::regex pattern, StatFunc stat_func) : interval(interval), pattern(pattern), stat_func(stat_func), step(0) { } -inline void Monitor::install(Executor *exe, bool monitor_all = false) { +inline void Monitor::install(Executor *exe, bool monitor_all) { MXExecutorSetMonitorCallback(exe->handle_, static_cast(&Monitor::executor_callback), this, monitor_all); From 57c24235198a188a005a0e38fb2b056cd70f10f6 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 20 Dec 2018 17:42:34 +0800 Subject: [PATCH 03/38] Fix build. --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 2 +- src/operator/quantization/quantize_graph_pass.cc | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index 8e115fab7170..cecf764518c5 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -58,7 +58,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, auto in_ptr = in_buffer.data().dptr(); #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) \ reduction(min : data_min) reduction(max : data_max) - for (int64_t i = 0; i < in_buffer.shape().Size(); i++) { + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { if (in_ptr[i] > data_max) data_max = in_ptr[i]; if (in_ptr[i] < data_min) data_min = in_ptr[i]; } diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 8706087f9879..676615cf0402 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -312,8 +312,7 @@ Graph QuantizeGraph(Graph &&src) { } } - if (!offline_params.empty()) outputs = - OfflineParams(std::move(outputs), std::move(offline_params)); + if (!offline_params.empty()) outputs = OfflineParams(std::move(outputs), offline_params); Graph ret; ret.outputs = std::move(outputs); From 923b9ee66d9ed37c7a266669d291c6e8056e9333 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 20 Dec 2018 18:10:18 +0800 Subject: [PATCH 04/38] Fix build --- src/operator/quantization/quantize_graph_pass.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 676615cf0402..af533978a6f5 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -64,7 +64,7 @@ NodePtr InsertNode(std::string op_name, } std::vector OfflineParams(std::vector&& outputs, - const std::unordered_set&& offline_params) { + const std::unordered_set& offline_params) { std::string node_suffixs[3] = {"", "_min", "_max"}; std::unordered_map mirror_map; nnvm::NodeEntryMap entry_var; @@ -89,7 
+89,8 @@ std::vector OfflineParams(std::vector&& outputs, return outputs; } -inline bool NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes) { +inline bool NeedQuantize(const NodePtr node, + const std::unordered_set& excluded_nodes) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); const auto& op = node->op(); From 5ea756cddb163497cf24fec7b2a4ee1ec0a1ed1b Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 20 Dec 2018 20:35:53 +0800 Subject: [PATCH 05/38] Remove openmp min/max reduction for windows build --- .../quantization/mkldnn/mkldnn_quantize_v2-inl.h | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index cecf764518c5..806f67f01eb7 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -56,11 +56,18 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, // no calib info in_buffer = inputs[0].Reorder2Default(); auto in_ptr = in_buffer.data().dptr(); -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) \ - reduction(min : data_min) reduction(max : data_max) + auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + std::vector data_maxs(nthreads, data_max); + std::vector data_mins(nthreads, data_min); +#pragma omp parallel for num_threads(nthreads) for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { - if (in_ptr[i] > data_max) data_max = in_ptr[i]; - if (in_ptr[i] < data_min) data_min = in_ptr[i]; + int tid = omp_get_thread_num(); + if (in_ptr[i] > data_maxs[tid]) data_maxs[tid] = in_ptr[i]; + if (in_ptr[i] < data_mins[tid]) data_mins[tid] = in_ptr[i]; + } + for (index_t i = 0; i < nthreads; i++) { + if (data_maxs[i] > data_max) data_max = data_maxs[i]; + if (data_mins[i] < data_min) data_min = data_mins[i]; } } auto out_type = GetOutputType(param); From d1ba1ad432f79e76db2504def3574ee6d4394c39 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sun, 23 Dec 2018 20:42:13 +0800 Subject: [PATCH 06/38] Add mkldnn_OIhw4i16o4i_s8s8 support --- src/operator/nn/mkldnn/mkldnn_base.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 454d57f5ca90..215c9b0a5bd7 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -308,6 +308,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { case mkldnn_oIhw8i: case mkldnn_oIhw16i: case mkldnn_OIhw8i8o: + case mkldnn_hwio_s8s8: case mkldnn_OIhw16i16o: case mkldnn_OIhw4i16o4i: case mkldnn_OIhw4i16o4i_s8s8: From 7531e98db935a87615e05ab0c81435795a90e857 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sun, 23 Dec 2018 21:10:52 +0800 Subject: [PATCH 07/38] Add all s8s8 weight format --- src/operator/nn/mkldnn/mkldnn_base.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 215c9b0a5bd7..0cf44d3d78f8 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -331,9 +331,11 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { switch (desc.data.format) { case mkldnn_goihw: case mkldnn_hwigo: + case mkldnn_hwigo_s8s8: case mkldnn_gOIhw8i8o: case mkldnn_gOIhw16i16o: 
case mkldnn_gOIhw4i16o4i: + case mkldnn_gOIhw4i16o4i_s8s8: case mkldnn_gOIhw8i16o2i: case mkldnn_gOIhw8o16i2o: case mkldnn_gOIhw8o8i: From 3f75e82fd385d9c2667b27c27e489d6697384338 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 26 Dec 2018 12:21:46 +0800 Subject: [PATCH 08/38] Change ssd quantize script. --- example/ssd/quantization.py | 13 +++++-------- include/mxnet/executor.h | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py index 4ed28dd03c2f..8cdde894dc24 100644 --- a/example/ssd/quantization.py +++ b/example/ssd/quantization.py @@ -51,7 +51,7 @@ def save_params(fname, arg_params, aux_params, logger=None): parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--num-calib-batches', type=int, default=5, help='number of batches for calibration') - parser.add_argument('--exclude-first-conv', action='store_true', default=True, + parser.add_argument('--exclude-first-conv', action='store_true', default=False, help='excluding quantizing the first conv layer since the' ' number of channels is usually not a multiple of 4 in that layer' ' which does not satisfy the requirement of cuDNN') @@ -78,8 +78,8 @@ def save_params(fname, arg_params, aux_params, logger=None): ' thresholds. This mode is expected to produce the best inference accuracy of all three' ' kinds of quantized models if the calibration dataset is representative enough of the' ' inference dataset.') - parser.add_argument('--quantized-dtype', type=str, default='uint8', - choices=['int8', 'uint8'], + parser.add_argument('--quantized-dtype', type=str, default='auto', + choices=['auto', 'int8', 'uint8'], help='quantization destination data type for input data') args = parser.parse_args() @@ -119,12 +119,9 @@ def save_params(fname, arg_params, aux_params, logger=None): exclude_first_conv = args.exclude_first_conv excluded_sym_names = [] rgb_mean = '123,117,104' - calib_layer = lambda name: name.endswith('_output') for i in range(1,19): excluded_sym_names += ['flatten'+str(i)] - excluded_sym_names += ['relu4_3_cls_pred_conv', - 'relu7_cls_pred_conv', - 'relu4_3_loc_pred_conv'] + if exclude_first_conv: excluded_sym_names += ['conv1_1'] @@ -156,7 +153,7 @@ def save_params(fname, arg_params, aux_params, logger=None): ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, calib_data=eval_iter, num_calib_examples=num_calib_batches * batch_size, - calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + calib_layer=None, quantized_dtype=args.quantized_dtype, label_names=(label_name,), logger=logger) sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300') param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch) diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index c3f2459c0f59..877b1300e264 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -174,7 +174,7 @@ class Executor { /*! * \brief Install a callback to notify the completion of operation. 
*/ - virtual void SetMonitorCallback(const MonitorCallback& callback, bool monitor_input) {} + virtual void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) {} }; // class executor } // namespace mxnet #endif // MXNET_EXECUTOR_H_ From 0acd9ce4f0864dc0f96ed91dbbd660eb7128f51f Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 2 Jan 2019 14:23:05 +0800 Subject: [PATCH 09/38] Update --- .../quantization/imagenet_gen_qsym_mkldnn.py | 9 +-- include/mxnet/c_api.h | 2 +- include/mxnet/executor.h | 2 +- python/mxnet/contrib/quantization.py | 61 ++++++++----------- src/operator/nn/mkldnn/mkldnn_base.cc | 3 +- .../mkldnn/mkldnn_quantize_v2-inl.h | 24 ++++---- src/operator/quantization/quantize.cc | 2 +- src/operator/quantization/quantize_v2-inl.h | 2 +- src/operator/quantization/quantize_v2.cc | 2 +- 9 files changed, 48 insertions(+), 59 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 7fa7324beaed..561406cb3916 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -198,6 +198,7 @@ def save_params(fname, arg_params, aux_params, logger=None): # get image shape image_shape = args.image_shape + calib_layer = lambda name: name.endswith('_output') or name == "data" exclude_first_conv = args.exclude_first_conv excluded_sym_names = [] if args.model == 'imagenet1k-resnet-152': @@ -243,11 +244,7 @@ def save_params(fname, arg_params, aux_params, logger=None): rgb_mean = '0,0,0' rgb_std = '0,0,0' # add layer names you donnot want to quantize. - # add conv/pool layer names that has negative inputs - # since Intel MKL-DNN only support uint8 quantization temporary. - # add all fc layer names since Intel MKL-DNN does not support temporary. excluded_sym_names += ['layers'] - # add your first conv layer names since Intel MKL-DNN only support uint8 quantization temporary. 
if exclude_first_conv: excluded_sym_names += ['layers'] else: @@ -264,7 +261,7 @@ def save_params(fname, arg_params, aux_params, logger=None): mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} logger.info('rgb_std = %s' % rgb_std) rgb_std = [float(i) for i in rgb_std.split(',')] - std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} + std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} combine_mean_std = {} combine_mean_std.update(mean_args) combine_mean_std.update(std_args) @@ -294,7 +291,7 @@ def save_params(fname, arg_params, aux_params, logger=None): ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, calib_data=data, num_calib_examples=num_calib_batches * batch_size, - calib_layer=None, quantized_dtype=args.quantized_dtype, + calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, label_names=(label_name,), logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 1c7575ccd688..96b8c967e8ee 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1837,7 +1837,7 @@ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle, - bool monitor_all); + bool monitor_all = false); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 877b1300e264..aec10091a540 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -174,7 +174,7 @@ class Executor { /*! * \brief Install a callback to notify the completion of operation. 
*/ - virtual void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) {} + virtual void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all = false) {} }; // class executor } // namespace mxnet #endif // MXNET_EXECUTOR_H_ diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 0f32cbc82f37..352512c79fb9 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -26,7 +26,6 @@ import ctypes import logging import os -import sys import numpy as np from ..base import _LIB, check_call, py_str from ..base import c_array, c_str, mx_uint, c_str_array @@ -136,20 +135,17 @@ def __init__(self, include_layer=None, logger=None): def collect(self, name, arr): """Callback function for collecting layer output NDArrays.""" - try: - name = py_str(name) - if self.include_layer is not None and not self.include_layer(name): - return - handle = ctypes.cast(arr, NDArrayHandle) - arr = NDArray(handle, writable=False).copyto(cpu()) - if self.logger is not None: - self.logger.info("Collecting layer %s output of shape %s" % (name, arr.shape)) - if name in self.nd_dict: - self.nd_dict[name].append(arr) - else: - self.nd_dict[name] = [arr] - except KeyboardInterrupt: - sys.exit(1) + name = py_str(name) + if self.include_layer is not None and not self.include_layer(name): + return + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False).copyto(cpu()) + if self.logger is not None: + self.logger.info("Collecting layer %s output of shape %s" % (name, arr.shape)) + if name in self.nd_dict: + self.nd_dict[name].append(arr) + else: + self.nd_dict[name] = [arr] class _LayerOutputMinMaxCollector(object): """Saves layer output min and max values in a dict with layer names as keys. 
@@ -162,25 +158,22 @@ def __init__(self, include_layer=None, logger=None): def collect(self, name, arr): """Callback function for collecting min and max values from an NDArray.""" - try: - name = py_str(name) - if self.include_layer is not None and not self.include_layer(name): - return - handle = ctypes.cast(arr, NDArrayHandle) - arr = NDArray(handle, writable=False) - min_range = ndarray.min(arr).asscalar() - max_range = ndarray.max(arr).asscalar() - if name in self.min_max_dict: - cur_min_max = self.min_max_dict[name] - self.min_max_dict[name] = (min(cur_min_max[0], min_range), - max(cur_min_max[1], max_range)) - else: - self.min_max_dict[name] = (min_range, max_range) - if self.logger is not None: - self.logger.info("Collecting layer %s min_range=%f, max_range=%f" - % (name, min_range, max_range)) - except KeyboardInterrupt: - sys.exit(1) + name = py_str(name) + if self.include_layer is not None and not self.include_layer(name): + return + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False) + min_range = ndarray.min(arr).asscalar() + max_range = ndarray.max(arr).asscalar() + if name in self.min_max_dict: + cur_min_max = self.min_max_dict[name] + self.min_max_dict[name] = (min(cur_min_max[0], min_range), + max(cur_min_max[1], max_range)) + else: + self.min_max_dict[name] = (min_range, max_range) + if self.logger is not None: + self.logger.info("Collecting layer %s min_range=%f, max_range=%f" + % (name, min_range, max_range)) def _calibrate_quantized_sym(qsym, th_dict): """Given a dictionary containing the thresholds for quantizing the layers, diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 0cf44d3d78f8..d1ada8e8da6d 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -262,8 +262,7 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) { const mkldnn::memory *GetWeights(const NDArray &arr, const mkldnn::memory::primitive_desc &target_pd, int num_groups) { const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd); - // If the weight array already uses the target layout, simply return it - // directly. + // If the weight array already uses the target layout, simply return it directly. 
if (mem) return mem; mem = GetWeights(arr, num_groups); if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_pd); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index 806f67f01eb7..e57e649bb45f 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -43,11 +43,11 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, using namespace mxnet_op; using red::limits::MaxValue; using red::limits::MinValue; - float real_range = 0.0; - float quantized_range = 0.0; + SrcType real_range = 0.f; + DstType quantized_range = 0; NDArray in_buffer = inputs[0]; - float data_min = red::limits::MaxValue(); - float data_max = red::limits::MinValue(); + SrcType data_min = red::limits::MaxValue(); + SrcType data_max = red::limits::MinValue(); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { data_min = param.min_calib_range.value(); @@ -55,10 +55,10 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, } else { // no calib info in_buffer = inputs[0].Reorder2Default(); - auto in_ptr = in_buffer.data().dptr(); + auto in_ptr = in_buffer.data().dptr(); auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - std::vector data_maxs(nthreads, data_max); - std::vector data_mins(nthreads, data_min); + std::vector data_maxs(nthreads, data_max); + std::vector data_mins(nthreads, data_min); #pragma omp parallel for num_threads(nthreads) for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { int tid = omp_get_thread_num(); @@ -72,10 +72,10 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, } auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { - real_range = MaxAbs(data_min, data_max); - quantized_range = MaxAbs(MaxValue(), MinValue()); - *outputs[1].data().dptr() = data_min; - *outputs[2].data().dptr() = data_max; + real_range = std::max(0.f, data_max); + quantized_range = MaxValue(); + *outputs[1].data().dptr() = 0.f; + *outputs[2].data().dptr() = real_range; } else if (out_type == mshadow::kInt8) { real_range = MaxAbs(data_min, data_max); quantized_range = MinAbs(MaxValue(), MinValue()); @@ -84,7 +84,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, } else { LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; } - float scale = quantized_range / real_range; + float scale = static_cast(quantized_range) / real_range; primitive_attr attr; const int mask = 0; diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc index 5227751bc635..e486f058bfd5 100644 --- a/src/operator/quantization/quantize.cc +++ b/src/operator/quantization/quantize.cc @@ -71,7 +71,7 @@ where `scale = quantized_range / MaxAbs(min_range, max_range).` .. Note:: - This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) + This operator only supports forward propagation. 
DO NOT use it in training.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(3) .set_num_outputs(3) diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 9ba2c3a5437c..8623189faeec 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -47,7 +47,7 @@ struct QuantizeV2Param : public dmlc::Parameter { .add_enum("auto", kAuto) .add_enum("int8", kInt8) .add_enum("uint8", kUint8) - .set_default(kUint8) + .set_default(kAuto) .describe("Output data type. `auto` can be specified to automatically determine output type " "according to min_calib_range."); DMLC_DECLARE_FIELD(min_calib_range) diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index afa341f8a780..5009bfadc20c 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -74,7 +74,7 @@ If min_calib_range < 0.0f, the output type will be int8, otherwise will be uint8 If min_calib_range isn't presented, the output type will be int8. .. Note:: - This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(3) From a713664ceeb23b0740af23e90f80ecd63c09fb07 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 3 Jan 2019 13:12:52 +0800 Subject: [PATCH 10/38] Manually cast mshadow shape size to size_t --- include/mxnet/tensor_blob.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index 496e8c7cfced..412877a58218 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -287,7 +287,7 @@ class TBlob { CHECK(Device::kDevMask == this->dev_mask()) << "TBlob.get: device type do not match specified type"; CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous"; - CHECK_EQ(this->shape_.Size(), shape.Size()) + CHECK_EQ(this->shape_.Size(), static_cast(shape.Size())) << "TBlob.get_with_shape: new and old shape do not match total elements"; return mshadow::Tensor(dptr(), shape, shape[dim - 1], stream); From 9129b4183678250b0dadab6d8a34e47eee77cff6 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 3 Jan 2019 14:53:43 +0800 Subject: [PATCH 11/38] Fix merge. --- src/operator/nn/mkldnn/mkldnn_convolution.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index f0fd961c8062..c1e1433eab79 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -136,9 +136,6 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP const NDArray &weights, const NDArray *bias, const NDArray &output) { - CHECK_GE(param.conv_param.stride.ndim(), 2U); - CHECK_GE(param.conv_param.pad.ndim(), 2U); - CHECK_GE(param.conv_param.dilate.ndim(), 2U); auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.conv_param.num_group, param.mkldnn_param.quantized); auto out_md = GetMemDesc(output); From 159dc117fbe3063caf95ef248ea558a7183121f8 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 3 Jan 2019 15:19:08 +0800 Subject: [PATCH 12/38] Fix perl package. 
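
The Perl set_monitor_callback binding now forwards the new monitor_all flag to
MXExecutorSetMonitorCallback, so Perl users can monitor operator inputs as well
as outputs, in line with the C++ and Python executor changes in this series.
A minimal usage sketch of the intended behaviour, assuming the Python Monitor
exposes the same monitor_all keyword (names outside this patch are assumptions,
not part of this change):

    import mxnet as mx

    # Sketch only: 'monitor_all' is assumed to be the flag introduced by this
    # series; when True, operator inputs are recorded in addition to outputs.
    mon = mx.mon.Monitor(interval=1, pattern='.*', monitor_all=True)

    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data, num_hidden=4, name='fc')
    exe = fc.simple_bind(mx.cpu(), data=(2, 8))

    mon.install(exe)
    mon.tic()
    exe.forward(is_train=False, data=mx.nd.ones((2, 8)))
    mon.toc_print()  # with monitor_all=True, stats cover fc inputs and outputs
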
--- perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm index 573abbf588f2..190177d59c1d 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm @@ -254,14 +254,17 @@ method backward( ---------- $callback : CodeRef Takes a string and an NDArrayHandle. + $monitor_all : Bool, default 0 + If true, monitor both input and output, otherwise monitor output only. =cut -method set_monitor_callback(CodeRef $callback) +method set_monitor_callback(CodeRef $callback, Bool $monitor_all=0) { check_call( AI::MXNetCAPI::ExecutorSetMonitorCallback( $self->handle, - $callback + $callback, + $monitor_all ) ); } From 3cd9d99f69c6050347a77dd3f9e4dd17d57d1cdc Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 3 Jan 2019 16:48:54 +0800 Subject: [PATCH 13/38] Retrigger CI --- src/operator/quantization/quantize_v2.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 5009bfadc20c..4fabe393e892 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -22,6 +22,7 @@ * \file quantize.cc * \brief */ + #include "./quantize_v2-inl.h" #if MXNET_USE_MKLDNN == 1 #include "./mkldnn/mkldnn_quantize_v2-inl.h" From 0b98e9485e4eb9acec6124064812d39398f73906 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 4 Jan 2019 13:42:00 +0800 Subject: [PATCH 14/38] Fix GPU test --- src/operator/quantization/quantize_v2-inl.h | 102 ++++++++------------ src/operator/quantization/quantize_v2.cc | 8 ++ 2 files changed, 46 insertions(+), 64 deletions(-) diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 8623189faeec..f1532f047849 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -85,15 +85,13 @@ static mshadow::TypeFlag GetOutputType(const QuantizeV2Param ¶m) { struct quantize_v2_unsigned { template MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, - const SrcDType *in, const float *imin_range, - const float *imax_range, const double min_limit, + const SrcDType *in, const float imin_range, + const float imax_range, const double min_limit, const double max_limit) { - using mshadow::red::limits::MaxValue; - using mshadow::red::limits::MinValue; - const float scale = (max_limit - min_limit) / (*imax_range - *imin_range); - out[i] = static_cast((in[i] - *imin_range) * scale + 0.5); - *omin_range = *imin_range; - *omax_range = *imax_range; + const float scale = (max_limit - min_limit) / (imax_range - imin_range); + out[i] = static_cast((in[i] - imin_range) * scale + 0.5); + *omin_range = imin_range; + *omax_range = imax_range; } }; @@ -101,9 +99,9 @@ struct quantize_v2_unsigned { struct quantize_v2_zero_centered { template MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, - const SrcDType *in, const float *imin_range, - const float *imax_range, const float quantized_range) { - float real_range = MaxAbs(*imin_range, *imax_range); + const SrcDType *in, const float imin_range, + const float imax_range, const float quantized_range) { + float real_range = MaxAbs(imin_range, imax_range); float scale = quantized_range / real_range; SrcDType x = in[i]; out[i] = static_cast(Sign(x) * Min(Abs(x) * scale + 
0.5f, quantized_range)); @@ -122,72 +120,48 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, using mshadow::red::limits::MaxValue; using mshadow::red::limits::MinValue; Stream *s = ctx.get_stream(); + SrcDType in_min; + SrcDType in_max; const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); auto out_type = GetOutputType(param); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - auto in_min = param.min_calib_range.value(); - auto in_max = param.max_calib_range.value(); - if (out_type == mshadow::kUint8) { - Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), - outputs[1].dptr(), outputs[2].dptr(), - inputs[0].dptr(), &in_min, &in_max, - MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), &in_min, &in_max, - MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; - } + in_min = param.min_calib_range.value(); + in_max = param.max_calib_range.value(); } else { // model is not calibrated TShape src_shape, dst_shape; const size_t actual_float_size = sizeof(float); - const size_t actual_quantized_size = sizeof(SrcDType); const size_t temp_reduce_size = ConfigReduce(s, inputs[0].shape_, TShape({1}), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( - Shape1(2 * actual_float_size + 2 * actual_quantized_size + temp_reduce_size), s); - Tensor actual_min_float(reinterpret_cast(temp_space.dptr_), Shape1(1), - s); - Tensor actual_max_float(reinterpret_cast(temp_space.dptr_) + 1, - Shape1(1), s); - + Shape1(2 * actual_float_size + temp_reduce_size), s); const int dev_id = ctx.run_ctx.ctx.dev_id; - TBlob actual_min_quantized(reinterpret_cast(temp_space.dptr_ + 8), Shape1(1), - xpu::kDevMask, dev_id); - TBlob actual_max_quantized(reinterpret_cast(temp_space.dptr_ + 8) + 1, Shape1(1), - xpu::kDevMask, dev_id); - Tensor workspace( - temp_space.dptr_ + 2 * actual_float_size + 2 * actual_quantized_size, - Shape1(temp_reduce_size), s); + TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, + dev_id); + TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, + dev_id); + Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, + Shape1(temp_reduce_size), s); broadcast::Reduce( - s, actual_min_quantized.reshape(dst_shape), kWriteTo, workspace, - inputs[0].reshape(src_shape)); - Kernel::Launch(s, 1, actual_min_float.dptr_, - actual_min_quantized.dptr(), - inputs[1].dptr(), inputs[2].dptr()); - + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); broadcast::Reduce( - s, actual_max_quantized.reshape(dst_shape), kWriteTo, workspace, - inputs[0].reshape(src_shape)); - Kernel::Launch(s, 1, actual_max_float.dptr_, - actual_max_quantized.dptr(), - inputs[1].dptr(), inputs[2].dptr()); - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), actual_min_float.dptr_, - actual_max_float.dptr_, MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), actual_min_float.dptr_, - actual_max_float.dptr_, MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only 
supports int8 and uint8 as output type"; - } + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + in_min = *in_min_t.dptr(); + in_max = *in_max_t.dptr(); + } + + if (out_type == mshadow::kUint8) { + Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), + outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), in_min, in_max, + MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min, in_max, + MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; } } diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 4fabe393e892..7f773a8ceeed 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -91,6 +91,14 @@ If min_calib_range isn't presented, the output type will be int8. .set_attr("FComputeEx", MKLDNNQuantizeV2Compute) #endif .set_attr("FCompute", QuantizeV2Compute) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) { + const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + return std::vector(); + } else { + return std::vector(1, ResourceRequest::kTempSpace); + } + }) .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") .add_arguments(QuantizeV2Param::__FIELDS__()); From 989477b2b36d585e6abafaff8a9449e66de9b0b5 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 4 Jan 2019 16:33:43 +0800 Subject: [PATCH 15/38] Fix GPU test --- src/operator/quantization/quantize_v2-inl.h | 62 ++++++++++++++------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index f1532f047849..95d5a75e835d 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -93,6 +93,14 @@ struct quantize_v2_unsigned { *omin_range = imin_range; *omax_range = imax_range; } + + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, + const SrcDType *in, const float *imin_range, + const float *imax_range, const double min_limit, + const double max_limit) { + Map(i, out, omin_range, omax_range, in, *imin_range, *imax_range, min_limit, max_limit); + } }; // keep zero-center @@ -108,6 +116,13 @@ struct quantize_v2_zero_centered { *omin_range = -real_range; *omax_range = real_range; } + + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, + const SrcDType *in, const float *imin_range, + const float *imax_range, const float quantized_range) { + Map(i, out, omin_range, omax_range, in, *imin_range, *imax_range, quantized_range); + } }; template @@ -120,14 +135,22 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, using mshadow::red::limits::MaxValue; using mshadow::red::limits::MinValue; Stream *s = ctx.get_stream(); - SrcDType in_min; - SrcDType in_max; - const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); auto out_type = GetOutputType(param); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - in_min = param.min_calib_range.value(); - in_max = param.max_calib_range.value(); + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), 
outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } } else { // model is not calibrated TShape src_shape, dst_shape; const size_t actual_float_size = sizeof(float); @@ -146,22 +169,19 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); broadcast::Reduce( s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - in_min = *in_min_t.dptr(); - in_max = *in_max_t.dptr(); - } - - if (out_type == mshadow::kUint8) { - Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), - outputs[1].dptr(), outputs[2].dptr(), - inputs[0].dptr(), in_min, in_max, - MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min, in_max, - MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } } } From 894f13f3f1a5cfe54603bfcdb36bb6f2cadda213 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 4 Jan 2019 20:06:56 +0800 Subject: [PATCH 16/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index e57e649bb45f..f32ae00816c8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,6 +70,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if (data_mins[i] < data_min) data_min = data_mins[i]; } } + auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From 60c16a7dcdc8289f7ae0c556141420d61e135c94 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 4 Jan 2019 21:30:46 +0800 Subject: [PATCH 17/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index f32ae00816c8..e57e649bb45f 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,7 +70,6 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if 
(data_mins[i] < data_min) data_min = data_mins[i]; } } - auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From 95ab7a13dd3d3cac686cf45b559d554526ebdfc3 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sat, 5 Jan 2019 08:43:38 +0800 Subject: [PATCH 18/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index e57e649bb45f..f32ae00816c8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,6 +70,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if (data_mins[i] < data_min) data_min = data_mins[i]; } } + auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From c861cfcc77271568d403d604a712f7c0249b5a8a Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sat, 5 Jan 2019 09:49:58 +0800 Subject: [PATCH 19/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index f32ae00816c8..e57e649bb45f 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,7 +70,6 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if (data_mins[i] < data_min) data_min = data_mins[i]; } } - auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From 4dacf2f3747247cad057a47dc7f77296019d914c Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 7 Jan 2019 12:37:48 +0800 Subject: [PATCH 20/38] Remove weight_channelwise_scale from params. 
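
Whether per-channel weight scales are used is now derived from the presence of
min_calib_range/max_calib_range instead of being a separate operator parameter:
the calibrated path uses one requantize scale per output channel, while the
uncalibrated path falls back to a single scale. A rough Python sketch of the
scale arithmetic performed in mkldnn_conv.cc (the 127/255 constants are assumed
to mirror kInt8Range/kUint8Range; this is an illustration, not a new API):

    # Sketch of the requantize-scale math used by the fused MKL-DNN convolution.
    INT8_RANGE = 127.0    # assumed value of kInt8Range
    UINT8_RANGE = 255.0   # assumed value of kUint8Range

    def conv_requantize_scales(data_min, data_max, weight_scales,
                               out_min, out_max, int8_data, uint8_out):
        data_range = INT8_RANGE if int8_data else UINT8_RANGE
        data_scale = data_range / max(abs(data_min), abs(data_max))
        out_range = max(abs(out_min), abs(out_max))
        output_scale = (UINT8_RANGE if uint8_out else INT8_RANGE) / out_range
        # One entry per output channel when calibration ranges are present
        # (channel-wise weight scales); otherwise weight_scales holds one value.
        return [output_scale / data_scale / w for w in weight_scales]

    # Example: a calibrated model with two output channels.
    print(conv_requantize_scales(-1.0, 1.0, [0.5, 0.25], -4.0, 4.0,
                                 int8_data=True, uint8_out=False))
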
--- src/operator/nn/mkldnn/mkldnn_convolution-inl.h | 3 --- src/operator/nn/mkldnn/mkldnn_convolution.cc | 2 +- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 10 ++++------ 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index a27dced910cb..ab6650eadad7 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -42,7 +42,6 @@ struct MKLDNNConvParam : public dmlc::Parameter { bool with_sum; bool with_postsum_relu; bool quantized; - bool weight_channelwise_scale; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset @@ -58,8 +57,6 @@ struct MKLDNNConvParam : public dmlc::Parameter { .describe("Add post relu after sum"); DMLC_DECLARE_FIELD(quantized).set_default(false) .describe("enable quantization"); - DMLC_DECLARE_FIELD(weight_channelwise_scale).set_default(true) - .describe("Quantize weight with channel wise scales."); DMLC_DECLARE_FIELD(min_calib_range) .set_default(dmlc::optional()) .describe("The minimum scalar value in the form of float32 obtained " diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index c1e1433eab79..96cb2de81987 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -91,7 +91,7 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( attr.set_post_ops(ops); if (param.mkldnn_param.quantized && param.requantize_scales.size()) { - int mask = param.mkldnn_param.weight_channelwise_scale ? 2 : 0; + int mask = (param.requantize_scales.size() > 1) ? 2 : 0; attr.set_output_scales(mask, param.requantize_scales); attr.set_int_output_round_mode(round_nearest); } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 8b72aec2ce43..04f95d6c3783 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -307,19 +307,18 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, CHECK_EQ(data.dtype(), mshadow::kInt8) << "Expect int8 when data_min < 0.0, consider quantize model with int8."; } + auto weight_channelwise_scale = false; if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) { cached_output_min_ = mkldnn_param.min_calib_range.value(); cached_output_max_ = mkldnn_param.max_calib_range.value(); post_requantize_ = true; - mkldnn_param.weight_channelwise_scale = true; - } else { - mkldnn_param.weight_channelwise_scale = false; + weight_channelwise_scale = true; } auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range; data_scale_ = data_range / MaxAbs(cached_data_min_, cached_data_max_); MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { weight_scales_ = - GetWeightScales(cached_weight_, mkldnn_param.weight_channelwise_scale); + GetWeightScales(cached_weight_, weight_channelwise_scale); }); // Collect scale. size_t channel = cached_weight_.shape()[0]; @@ -335,8 +334,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, quantized_out_range = IsOutputUInt8(mkldnn_param) ? kUint8Range : kInt8Range; out_range = MaxAbs(cached_output_min_, cached_output_max_); output_scale = quantized_out_range / out_range; - full_conv_param.requantize_scales.resize(mkldnn_param.weight_channelwise_scale ? 
channel - : 1); + full_conv_param.requantize_scales.resize(weight_channelwise_scale ? channel : 1); for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) { full_conv_param.requantize_scales[c] = output_scale / data_scale_ / weight_scales_[c]; } From 0a612e6ddf326a4362fdd67e48e2d0fe4c877078 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 10 Jan 2019 14:55:38 +0800 Subject: [PATCH 21/38] Fix --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 3 +-- src/operator/subgraph/mkldnn/mkldnn_conv_property.cc | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index ee2bba306397..916036e175a6 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -271,8 +271,7 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, static_cast(arr.shape()[C]), static_cast(arr.shape()[H]), static_cast(arr.shape()[W])}; } - return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()), - mkldnn::memory::format::any}; + return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format::any}; } } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index adfc41bb120f..e462191c2898 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -141,8 +141,7 @@ class SgMKLDNNConvProperty : public SubgraphProperty { disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0); disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0); - disable_all = - disable_all && disable_conv_bn && disable_conv_relu && disable_conv_sum; + disable_all = disable_all || (disable_conv_bn && disable_conv_relu && disable_conv_sum); if (disable_all) { LOG(INFO) << "MKLDNN Convolution optimization pass is disabled."; } else { From 1f7fc56b020bc363c5a1a1ce54c49014d239a741 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 14 Jan 2019 13:54:15 +0800 Subject: [PATCH 22/38] Keep API compatible. --- include/mxnet/c_api.h | 5 +++-- include/mxnet/ndarray.h | 7 +++++++ src/c_api/c_api_symbolic.cc | 3 ++- src/ndarray/ndarray.cc | 12 ++++++++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 96b8c967e8ee..61f588ad29c5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1556,12 +1556,13 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. + * \param calib_quantize **Deperated**. quantize op will always be calibrated if could. */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, const char **excluded_symbols, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype); + const char *quantized_dtype, const bool calib_quantize); /*! 
* \brief Set calibration table to node attributes in the sym @@ -1837,7 +1838,7 @@ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle, - bool monitor_all = false); + bool monitor_all); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 5de42e19a657..5f2567e97d3d 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -691,6 +691,13 @@ class NDArray { } #if MXNET_USE_MKLDNN == 1 + /* + * Create NDArray from mkldnn memory. + * mkldnn_mem The mkldnn memory to be managed. + * static_data If true, mkldnn memory won't be freed on destruction. + */ + explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true); + /* * Create NDArray from mkldnn memory. * mkldnn_mem The mkldnn memory to be managed. diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 0a49b88e5429..0901aa90de1e 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -650,7 +650,8 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const char **excluded_op_names, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype) { + const char *quantized_dtype, + const bool calib_quantize) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 0f0fed24d4e6..93a8c2850587 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -168,6 +168,18 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 +NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data) + : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + auto mem_pd = mkldnn_mem->get_primitive_desc(); + auto mem_desc = mem_pd.desc(); + shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); + dtype_ = get_mxnet_type(mem_desc.data.data_type); + auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_); + ptr_ = std::make_shared(data, 0); + ptr_->mkl_mem_ = std::make_shared(mem_pd, ptr_->shandle.dptr); + ptr_->static_data = static_data; +} + NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { auto mem_desc = mem_pd.desc(); From cffaa05f10ddba6e83cf884a969dc995db28070e Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 15 Jan 2019 09:11:22 +0800 Subject: [PATCH 23/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index e57e649bb45f..f32ae00816c8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,6 +70,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if (data_mins[i] < data_min) data_min = data_mins[i]; } } + auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From 12d9f1fbd461dd537d3c6fe537181d4e34cf5515 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 15 Jan 2019 12:12:14 +0800 Subject: [PATCH 24/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 - 1 file 
changed, 1 deletion(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index f32ae00816c8..e57e649bb45f 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,7 +70,6 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if (data_mins[i] < data_min) data_min = data_mins[i]; } } - auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From 807a48c94ca046d68268f1933b45b322eb779943 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 15 Jan 2019 12:56:22 +0800 Subject: [PATCH 25/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index e57e649bb45f..f32ae00816c8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -70,6 +70,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, if (data_mins[i] < data_min) data_min = data_mins[i]; } } + auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { real_range = std::max(0.f, data_max); From 62e3fc071da51cf9b2ac99b01cee715a3d37afa5 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 28 Jan 2019 20:45:18 +0800 Subject: [PATCH 26/38] Rerun CI --- src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index f32ae00816c8..e201d290e8c6 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -48,7 +48,6 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, NDArray in_buffer = inputs[0]; SrcType data_min = red::limits::MaxValue(); SrcType data_max = red::limits::MinValue(); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { data_min = param.min_calib_range.value(); data_max = param.max_calib_range.value(); From 69a6e2828a147dd08ad5918385d8864cdce3e0e1 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 29 Jan 2019 20:59:15 +0800 Subject: [PATCH 27/38] Address comments. --- include/mxnet/c_api.h | 2 +- .../mkldnn/mkldnn_dequantize-inl.h | 1 + src/operator/quantization/quantize_v2-inl.h | 6 ++-- src/operator/quantization/quantize_v2.cc | 29 +++++++++---------- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 402ef43f49fd..2f15fadc81f0 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1566,7 +1566,7 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. - * \param calib_quantize **Deperated**. quantize op will always be calibrated if could. + * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could. 
*/ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h index 3c65172c6116..b66adf787fef 100644 --- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h @@ -75,6 +75,7 @@ static void MKLDNNDequantizeComputeKer(const std::vector &inputs, } mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); if (i_fmt == mkldnn::memory::format::nhwc) { + // For 4d tensor, nchw is the default format i_fmt = mkldnn::memory::format::nchw; } auto o_desc = mkldnn::memory::desc(i_dims, diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 95d5a75e835d..5ae10a7e4fa8 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -47,7 +47,7 @@ struct QuantizeV2Param : public dmlc::Parameter { .add_enum("auto", kAuto) .add_enum("int8", kInt8) .add_enum("uint8", kUint8) - .set_default(kAuto) + .set_default(kInt8) .describe("Output data type. `auto` can be specified to automatically determine output type " "according to min_calib_range."); DMLC_DECLARE_FIELD(min_calib_range) @@ -76,7 +76,7 @@ static mshadow::TypeFlag GetOutputType(const QuantizeV2Param ¶m) { } else if (param.out_type == QuantizeV2Param::OutType::kUint8) { out_type = mshadow::kUint8; } else { - LOG(FATAL) << "Unsupported quantize output type."; + LOG(FATAL) << "Unsupported out_type in params: " < } else if (out_type == mshadow::kInt8) { TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8); } else { - LOG(FATAL) << "Unsupported out_type."; + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; } TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 7f773a8ceeed..21410933d35e 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -32,11 +32,9 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(QuantizeV2Param); -static bool QuantizeV2StorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { +static bool QuantizeV2StorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, + DispatchMode* dispatch_mode, std::vector* in_attrs, + std::vector* out_attrs) { *dispatch_mode = DispatchMode::kFCompute; #if MXNET_USE_MKLDNN == 1 if (dev_mask == mshadow::cpu::kDevMask) { @@ -79,10 +77,9 @@ If min_calib_range isn't presented, the output type will be int8. .set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(3) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data"}; - }) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"data"}; +}) .set_attr("FInferShape", QuantizeV2Shape) .set_attr("FInferType", QuantizeV2Type) .set_attr("FInferStorageType", QuantizeV2StorageType) @@ -92,13 +89,13 @@ If min_calib_range isn't presented, the output type will be int8. 
#endif
.set_attr<FCompute>("FCompute", QuantizeV2Compute<cpu>)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
-  const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
-  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-    return std::vector<ResourceRequest>();
-  } else {
-    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
-  }
-  })
+  const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    return std::vector<ResourceRequest>();
+  } else {
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
+  }
+})
.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`")
.add_arguments(QuantizeV2Param::__FIELDS__());

From bf655e2c8635930dc705544b5318fbc8cff4b6e9 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Fri, 1 Feb 2019 15:42:12 +0800
Subject: [PATCH 28/38] fix.

---
 src/operator/subgraph/mkldnn/mkldnn_conv.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index 04f95d6c3783..499d7390eaad 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -356,7 +356,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
    initalized_ = true;
  }

-  if (!mkldnn_param.quantized) {
+  if (mkldnn_param.quantized) {
    auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc());
    mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc());
    fwd_->SetNewMem(*data_mem, *mem);

From e3a8d0a59cac43056f76e822a4e8745ebb44c195 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Sat, 2 Feb 2019 11:49:51 +0800
Subject: [PATCH 29/38] Address debug build.

---
 src/operator/nn/mkldnn/mkldnn_convolution.cc | 94 ++++++++++----------
 1 file changed, 46 insertions(+), 48 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 96cb2de81987..b52f753dd4d9 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -45,12 +45,21 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
          (input.shape().ndim() == 4));
 }

-static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
-    const MKLDNNConvFullParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
-    const mkldnn::memory::desc &weight_md, const mkldnn::memory::desc *bias_md,
-    const mkldnn::memory::desc &out_md) {
-  auto engine = CpuEngine::Get()->get_engine();
+mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
+                                                            const bool is_train, const NDArray &data,
+                                                            const NDArray &weights,
+                                                            const NDArray *bias,
+                                                            const NDArray &output) {
   auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto data_md = GetMemDesc(data);
+  auto weight_md = GetWeightDesc(weights, param.conv_param.num_group, param.mkldnn_param.quantized);
+  auto out_md = GetMemDesc(output);
+  auto bias_md =
+      bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias))
+           : mkldnn::memory::desc{
+                 {}, mkldnn::memory::data_type::data_undef, mkldnn::memory::format::any};
+  auto bias_md_ptr = bias ? &bias_md : nullptr;
+
   mkldnn::memory::dims strides(param.conv_param.kernel.ndim());
   mkldnn::memory::dims padding(param.conv_param.kernel.ndim());
   if (param.conv_param.kernel.ndim() == 1) {
@@ -95,17 +104,39 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     attr.set_output_scales(mask, param.requantize_scales);
     attr.set_int_output_round_mode(round_nearest);
   }
+  auto GetConvFwdPd = [&param, &data, &weights, &output,
+                       &attr](const mkldnn::convolution_forward::desc &desc) {
+    auto engine = CpuEngine::Get()->get_engine();
+    try {
+      auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
+      while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
+             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
+             (!param.mkldnn_param.quantized &&
+              conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) {
+        CHECK(conv_pd.next_impl()) << "No convolution implementation for this request.";
+      }
+      return conv_pd;
+    } catch (mkldnn::error &e) {
+      if (e.status == mkldnn_unimplemented && param.mkldnn_param.quantized) {
+        LOG(ERROR) << "AVX512-BW support or Intel(R) MKL dependency is "
+                      "required for int8 convolution";
+      } else {
+        LOG(ERROR) << e.message;
+      }
+      throw;
+    }
+  };

-  if (param.conv_param.dilate.ndim() == 0 && bias_md == nullptr) {
+  if (param.conv_param.dilate.ndim() == 0 && bias_md_ptr == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
                                            weight_md, out_md, strides, padding, padding,
                                            mkldnn::padding_kind::zero);
-    return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
+    return GetConvFwdPd(desc);
   } else if (param.conv_param.dilate.ndim() == 0) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
-                                           weight_md, *bias_md, out_md, strides, padding, padding,
-                                           mkldnn::padding_kind::zero);
-    return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
+                                           weight_md, *bias_md_ptr, out_md, strides, padding,
+                                           padding, mkldnn::padding_kind::zero);
+    return GetConvFwdPd(desc);
   } else {
     mkldnn::memory::dims dilates(param.conv_param.kernel.ndim());
     if (param.conv_param.dilate.ndim() == 1) {
@@ -117,50 +148,17 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
       LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << param.conv_param.dilate.ndim()
                  << ", supporting only 1 or 2.";
     }
-    if (bias_md == nullptr) {
+    if (bias_md_ptr == nullptr) {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
                                              weight_md, out_md, strides, dilates, padding, padding,
                                              mkldnn::padding_kind::zero);
-      return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
+      return GetConvFwdPd(desc);
     } else {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
-                                             weight_md, *bias_md, out_md, strides, dilates, padding,
-                                             padding, mkldnn::padding_kind::zero);
-      return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
-    }
-  }
-}
-
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
-                                                           const bool is_train, const NDArray &data,
-                                                           const NDArray &weights,
-                                                           const NDArray *bias,
-                                                           const NDArray &output) {
-  auto data_md = GetMemDesc(data);
-  auto weight_md = GetWeightDesc(weights, param.conv_param.num_group, param.mkldnn_param.quantized);
-  auto out_md = GetMemDesc(output);
-  auto bias_md =
-      bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias))
-           : mkldnn::memory::desc{
-                 {}, mkldnn::memory::data_type::data_undef, mkldnn::memory::format::any};
-  auto bias_md_ptr = bias ? &bias_md : nullptr;
-  try {
-    auto conv_pd = GetConvFwdImpl(param, is_train, data_md, weight_md, bias_md_ptr, out_md);
-    while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-           (!param.mkldnn_param.quantized &&
-            conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) {
-      CHECK(conv_pd.next_impl()) << "No convolution implementation for this request.";
-    }
-    return conv_pd;
-  } catch (mkldnn::error &e) {
-    if (e.status == mkldnn_unimplemented && param.mkldnn_param.quantized) {
-      LOG(ERROR) << "AVX512-BW support or Intel(R) MKL dependency is "
-                    "required for int8 convolution";
-    } else {
-      LOG(ERROR) << e.message;
+                                             weight_md, *bias_md_ptr, out_md, strides, dilates,
+                                             padding, padding, mkldnn::padding_kind::zero);
+      return GetConvFwdPd(desc);
     }
-    throw;
   }
 }

From d58311b69580c369cee548186d88dc759fc9de21 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Sat, 2 Feb 2019 12:03:44 +0800
Subject: [PATCH 30/38] Add comment for next_impl

---
 src/operator/nn/mkldnn/mkldnn_convolution.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index b52f753dd4d9..a3aca98d9f81 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -113,6 +113,7 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
             (!param.mkldnn_param.quantized &&
              conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) {
+        // next_impl() will visit desc and engine, please make sure they are still alive here.
        CHECK(conv_pd.next_impl()) << "No convolution implementation for this request.";
      }
      return conv_pd;

From 93295ab94348bffd38eb16380f1c8225923cccc8 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Sat, 2 Feb 2019 15:48:22 +0800
Subject: [PATCH 31/38] Rerun ci

---
 src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index e201d290e8c6..f32ae00816c8 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -48,6 +48,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
  NDArray in_buffer = inputs[0];
  SrcType data_min = red::limits::MaxValue<SrcType>();
  SrcType data_max = red::limits::MinValue<SrcType>();
+
  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
    data_min = param.min_calib_range.value();
    data_max = param.max_calib_range.value();

From eae2557769857c420aec2046ee50f67f98df4adf Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Mon, 4 Feb 2019 12:08:13 +0800
Subject: [PATCH 32/38] Add new api MXExecutorSetMonitorCallbackEX

---
 cpp-package/include/mxnet-cpp/monitor.hpp    |  6 +++---
 include/mxnet/c_api.h                        | 12 ++++++++---
 .../AI-MXNet/lib/AI/MXNet/Executor.pm        |  7 ++-----
 perl-package/AI-MXNetCAPI/mxnet.i            |  5 ++---
 python/mxnet/contrib/quantization.py         |  2 --
 python/mxnet/executor.py                     |  2 +-
 .../native/org_apache_mxnet_native_c_api.cc  |  3 +--
 src/c_api/c_api_executor.cc                  | 21 +++++++++++++++----
 8 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/cpp-package/include/mxnet-cpp/monitor.hpp b/cpp-package/include/mxnet-cpp/monitor.hpp
index bd7f1927e906..4439e1bd3a7c 100644
--- a/cpp-package/include/mxnet-cpp/monitor.hpp
+++ b/cpp-package/include/mxnet-cpp/monitor.hpp
@@ -44,9 +44,9 @@ inline Monitor::Monitor(int interval, std::regex pattern, StatFunc stat_func)
 }

 inline void Monitor::install(Executor *exe, bool monitor_all) {
-  MXExecutorSetMonitorCallback(exe->handle_,
-                               static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback),
-                               this, monitor_all);
+  MXExecutorSetMonitorCallbackEX(exe->handle_,
+                                 static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback),
+                                 this, monitor_all);
   exes.push_back(exe);
 }

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2f15fadc81f0..d6e13ebcf051 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1843,12 +1843,18 @@ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,

 /*!
  * \brief set a call back to notify the completion of operation
- * \param monitor_all If true, monitor both input and output, otherwise monitor output only.
  */
 MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                            ExecutorMonitorCallback callback,
-                                           void* callback_handle,
-                                           bool monitor_all);
+                                           void* callback_handle);
+
+/*!
+ * \brief set a call back to notify the completion of operation
+ * \param monitor_all If true, monitor both input and output, otherwise monitor output only.
+ */
+MXNET_DLL int MXExecutorSetMonitorCallbackEX(ExecutorHandle handle,
+                                             ExecutorMonitorCallback callback,
+                                             void *callback_handle, bool monitor_all);
 //--------------------------------------------
 // Part 5: IO Interface
 //--------------------------------------------
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
index 190177d59c1d..573abbf588f2 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
@@ -254,17 +254,14 @@ method backward(
    ----------
    $callback : CodeRef
        Takes a string and an NDArrayHandle.
-    $monitor_all : Bool, default 0
-        If true, monitor both input and output, otherwise monitor output only.
 =cut

-method set_monitor_callback(CodeRef $callback, Bool $monitor_all=0)
+method set_monitor_callback(CodeRef $callback)
 {
    check_call(
        AI::MXNetCAPI::ExecutorSetMonitorCallback(
            $self->handle,
-            $callback,
-            $monitor_all
+            $callback
        )
    );
 }

diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i
index ca6623572dfb..0e6a05ea9695 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -1614,12 +1614,11 @@ int MXExecutorReshape(int partial_shaping,

 /*!
  * \brief set a call back to notify the completion of operation
- * \param monitor_all If true, monitor both input and output, otherwise monitor output only.
  */
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
-                                 void* callback_handle,
-                                 bool monitor_all);
+                                 void* callback_handle);
+
 //--------------------------------------------
 // Part 5: IO Interface
 //--------------------------------------------
diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 352512c79fb9..959f6e95dde8 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -283,8 +283,6 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
    min_val = np.min(arr)
    max_val = np.max(arr)
    th = max(abs(min_val), abs(max_val))
-    if min_val >= 0:
-        num_quantized_bins = (num_quantized_bins // 2) * 4 + 1

    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
    zero_bin_idx = num_bins // 2
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index ddb2dab1098e..7bf867579d6b 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -253,7 +253,7 @@ def set_monitor_callback(self, callback, monitor_all=False):
        """
        cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
        self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
-        check_call(_LIB.MXExecutorSetMonitorCallback(
+        check_call(_LIB.MXExecutorSetMonitorCallbackEX(
            self.handle,
            self._monitor_callback,
            None,
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index a1141274ea73..ea6e9c8f5ba4 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -925,8 +925,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallbac
  jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj);
  return MXExecutorSetMonitorCallback(reinterpret_cast<ExecutorHandle>(executorPtr),
                                      ExecutorMonitorCallbackFunc,
-                                      reinterpret_cast<void *>(callbackFuncObjGlb),
-                                      false);
+                                      reinterpret_cast<void *>(callbackFuncObjGlb));
 }

 JNIEXPORT jstring JNICALL Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index b15f2d508644..66566ed703eb 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -645,12 +645,25 @@ int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
  API_END_HANDLE_ERROR(delete s);
 }

-
-
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
-                                 void* callback_handle,
-                                 bool monitor_all) {
+                                 void* callback_handle) {
+  API_BEGIN();
+  ExecutorMonitorCallback callback_temp = callback;
+  void* callback_handle_temp = callback_handle;
+  std::function<void(const char *, void*)> clbk
+    = [callback_temp, callback_handle_temp](const char *name, void* handle) {
+    callback_temp(name, handle, callback_handle_temp);
+  };
+  Executor *exec = static_cast<Executor*>(handle);
+  exec->SetMonitorCallback(clbk, false);
+  API_END();
+}
+
+int MXExecutorSetMonitorCallbackEX(ExecutorHandle handle,
+                                   ExecutorMonitorCallback callback,
+                                   void* callback_handle,
+                                   bool monitor_all) {
   API_BEGIN();
   ExecutorMonitorCallback callback_temp = callback;
   void* callback_handle_temp = callback_handle;

From 11217c29446b3db84d221fd52f5863c007c6bd1f Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Mon, 4 Feb 2019 12:18:58 +0800
Subject: [PATCH 33/38] Add default value for monitor_all for cpp header.

---
 src/executor/graph_executor.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 722714716aa4..c899a6f5b463 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -68,7 +68,7 @@ class GraphExecutor : public Executor {
  const std::unordered_map<std::string, NDArray>& arg_grad_map() const override;
  const std::unordered_map<std::string, NDArray>& aux_state_map() const override;
  void Print(std::ostream &os) const override;  // NOLINT(*)
-  void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) override;
+  void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all = false) override;
  // Initialize the rest of attributes
  // after setting up arguments.
  void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g,

From bfc91a65565b1e026e41a8ee1875db56b5d7ef09 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Tue, 5 Feb 2019 08:55:33 +0800
Subject: [PATCH 34/38] Rerun CI

---
 src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index f32ae00816c8..e201d290e8c6 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -48,7 +48,6 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
  NDArray in_buffer = inputs[0];
  SrcType data_min = red::limits::MaxValue<SrcType>();
  SrcType data_max = red::limits::MinValue<SrcType>();
-
  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
    data_min = param.min_calib_range.value();
    data_max = param.max_calib_range.value();

From fe08128d6fc1ae880456b8ab155d791b50026224 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Tue, 5 Feb 2019 19:26:00 +0800
Subject: [PATCH 35/38] fix

---
 python/mxnet/contrib/quantization.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 959f6e95dde8..96183bb7a172 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -119,7 +119,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_
                                     c_str_array(excluded_symbols),
                                     mx_uint(num_offline),
                                     c_array(ctypes.c_char_p, offline),
-                                     c_str(quantized_dtype)))
+                                     c_str(quantized_dtype),
+                                     ctypes.c_bool(True)))
    return Symbol(out)

From 63dfdbf71a2ab9141511b3857265aa367b340148 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Tue, 5 Feb 2019 20:37:38 +0800
Subject: [PATCH 36/38] script change for uint8.

---
 example/quantization/imagenet_gen_qsym_mkldnn.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py
index 561406cb3916..d807e7f2d19d 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -200,6 +200,9 @@ def save_params(fname, arg_params, aux_params, logger=None):
    calib_layer = lambda name: name.endswith('_output') or name == "data"
    exclude_first_conv = args.exclude_first_conv
+    if args.quantized_dtype == "uint8":
+        logger.info('quantized dtype is set to uint8, will exclude first conv.')
+        exclude_first_conv = True
    excluded_sym_names = []
    if args.model == 'imagenet1k-resnet-152':
        rgb_mean = '0,0,0'

From 1210b5c250bf98674fcc463cafb1743d2557f743 Mon Sep 17 00:00:00 2001
From: xinyu-intel
Date: Sat, 9 Feb 2019 00:15:33 +0800
Subject: [PATCH 37/38] trigger ci

---
 trigger | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 trigger

diff --git a/trigger b/trigger
new file mode 100644
index 000000000000..f3b7704c6f02
--- /dev/null
+++ b/trigger
@@ -0,0 +1 @@
+trigger ci

From bff42ff55a5885946549a3a914f54a8c2d03d9fa Mon Sep 17 00:00:00 2001
From: xinyu-intel
Date: Sat, 9 Feb 2019 00:16:37 +0800
Subject: [PATCH 38/38] trigger ci

---
 trigger | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 trigger

diff --git a/trigger b/trigger
deleted file mode 100644
index f3b7704c6f02..000000000000
--- a/trigger
+++ /dev/null
@@ -1 +0,0 @@
-trigger ci
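
For reviewers who want to try the series end to end, the following is a minimal usage sketch. It is not part of the patches; the toy network, shapes, and parameter values are made up for illustration, and it assumes an MXNet build that already contains these commits. It exercises the user-visible change above: Executor.set_monitor_callback(..., monitor_all=True), which is now routed to the new MXExecutorSetMonitorCallbackEX C API while the original three-argument MXExecutorSetMonitorCallback keeps its old signature.

    # Illustrative sketch only; assumes an MXNet build containing this patch series.
    import mxnet as mx

    data = mx.sym.Variable('data')
    conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=8, name='conv0')
    exe = conv.simple_bind(mx.cpu(), data=(1, 3, 8, 8))
    exe.arg_dict['conv0_weight'][:] = 1.0
    exe.arg_dict['conv0_bias'][:] = 0.0

    def printer(name, arr):
        # With monitor_all=True the callback fires for operator inputs as well as
        # outputs, which is how calibration can observe the data fed to quantize_v2.
        print(name.decode('ascii') if isinstance(name, bytes) else name)

    # Routed to MXExecutorSetMonitorCallbackEX (see the executor.py hunk in patch 32).
    exe.set_monitor_callback(printer, monitor_all=True)
    exe.forward(data=mx.nd.ones((1, 3, 8, 8)), is_train=False)
    exe.outputs[0].wait_to_read()

On the quantization side no user-code change is needed: patch 35 threads an extra ctypes.c_bool(True) through the internal _quantize_symbol call, and patch 36 makes the example script force exclude_first_conv whenever uint8 output is requested.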