diff --git a/Jenkinsfile b/Jenkinsfile index f9e038f0c3cd..f3329cdb796f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -386,6 +386,24 @@ try { } } }, + 'Python2: Quantize GPU': { + node('mxnetlinux-gpu-p3') { + ws('workspace/ut-python2-quantize-gpu') { + init_git() + unpack_lib('gpu', mx_lib) + sh "ci/build.py --nvidiadocker --build --platform ubuntu_gpu /work/runtime_functions.sh unittest_ubuntu_python2_quantization_gpu" + } + } + }, + 'Python3: Quantize GPU': { + node('mxnetlinux-gpu-p3') { + ws('workspace/ut-python3-quantize-gpu') { + init_git() + unpack_lib('gpu', mx_lib) + sh "ci/build.py --nvidiadocker --build --platform ubuntu_gpu /work/runtime_functions.sh unittest_ubuntu_python3_quantization_gpu" + } + } + }, 'Python2: MKLDNN-CPU': { node('mxnetlinux-cpu') { ws('workspace/ut-python2-mkldnn-cpu') { diff --git a/benchmark/python/quantization/benchmark_op.py b/benchmark/python/quantization/benchmark_op.py new file mode 100644 index 000000000000..5ba7740cc918 --- /dev/null +++ b/benchmark/python/quantization/benchmark_op.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +import mxnet as mx +from mxnet.test_utils import check_speed + + +def quantize_int8_helper(data): + min_data = mx.nd.min(data) + max_data = mx.nd.max(data) + return mx.nd.contrib.quantize(data, min_data, max_data, out_type='int8') + + +def benchmark_convolution(data_shape, kernel, num_filter, pad, stride, no_bias=True, layout='NCHW', repeats=20): + ctx_gpu = mx.gpu(0) + data = mx.sym.Variable(name="data", shape=data_shape, dtype='float32') + # conv cudnn + conv_cudnn = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride, + no_bias=no_bias, layout=layout, cudnn_off=False, name="conv_cudnn") + arg_shapes, _, _ = conv_cudnn.infer_shape(data=data_shape) + input_data = mx.nd.random.normal(0, 0.2, shape=data_shape, ctx=ctx_gpu) + conv_weight_name = conv_cudnn.list_arguments()[1] + args = {data.name: input_data, conv_weight_name: mx.random.normal(0, 1, shape=arg_shapes[1], ctx=ctx_gpu)} + conv_cudnn_time = check_speed(sym=conv_cudnn, location=args, ctx=ctx_gpu, N=repeats, + grad_req='null', typ='forward') * 1000 + + # quantized_conv2d + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8') + weight = mx.sym.Variable(name='weight', shape=arg_shapes[1], dtype='int8') + min_data = mx.sym.Variable(name='min_data', shape=(1,), dtype='float32') + max_data = mx.sym.Variable(name='max_data', shape=(1,), dtype='float32') + min_weight = mx.sym.Variable(name='min_weight', shape=(1,), dtype='float32') + max_weight = mx.sym.Variable(name='max_weight', shape=(1,), dtype='float32') + quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=weight, min_data=min_data, max_data=max_data, + min_weight=min_weight, max_weight=max_weight, + kernel=kernel, num_filter=num_filter, pad=pad, stride=stride, + no_bias=no_bias, layout=layout, cudnn_off=False, + name='quantized_conv2d') + qargs = {qdata.name: quantize_int8_helper(input_data)[0], + min_data.name: quantize_int8_helper(input_data)[1], + max_data.name: quantize_int8_helper(input_data)[2], + weight.name: quantize_int8_helper(args[conv_weight_name])[0], + min_weight.name: quantize_int8_helper(args[conv_weight_name])[1], + max_weight.name: quantize_int8_helper(args[conv_weight_name])[2]} + qconv_time = check_speed(sym=quantized_conv2d, location=qargs, ctx=ctx_gpu, N=repeats, + grad_req='null', typ='forward') * 1000 + + print('==================================================================================================') + print('data=%s, kernel=%s, num_filter=%s, pad=%s, stride=%s, no_bias=%s, layout=%s, repeats=%s' + % (data_shape, kernel, num_filter, pad, stride, no_bias, layout, repeats)) + print('%s , ctx=%s, time=%.2f ms' % (conv_cudnn.name + '-FP32', ctx_gpu, conv_cudnn_time)) + print('%s, ctx=%s, time=%.2f ms' % (quantized_conv2d.name, ctx_gpu, qconv_time)) + print('quantization speedup: %.1fX' % (conv_cudnn_time / qconv_time)) + print('\n') + + +if __name__ == '__main__': + for batch_size in [32, 64, 128]: + benchmark_convolution(data_shape=(batch_size, 64, 56, 56), kernel=(1, 1), num_filter=256, + pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20) + + benchmark_convolution(data_shape=(batch_size, 256, 56, 56), kernel=(1, 1), num_filter=64, + pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20) + + benchmark_convolution(data_shape=(batch_size, 256, 56, 56), kernel=(1, 1), num_filter=128, + pad=(0, 0), stride=(2, 2), layout='NCHW', repeats=20) + + benchmark_convolution(data_shape=(batch_size, 128, 28, 28), kernel=(3, 3), num_filter=128, + pad=(1, 1), stride=(1, 1), layout='NCHW', repeats=20) + + benchmark_convolution(data_shape=(batch_size, 1024, 14, 14), kernel=(1, 1), num_filter=256, + pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20) + + benchmark_convolution(data_shape=(batch_size, 2048, 7, 7), kernel=(1, 1), num_filter=512, + pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index abb37022d668..1d005ce6dc69 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -361,6 +361,7 @@ unittest_ubuntu_python2_cpu() { export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 nosetests-2.7 --verbose tests/python/unittest nosetests-2.7 --verbose tests/python/train + nosetests-2.7 --verbose tests/python/quantization } unittest_ubuntu_python3_cpu() { @@ -371,6 +372,7 @@ unittest_ubuntu_python3_cpu() { #export MXNET_MKLDNN_DEBUG=1 # Ignored if not present export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 nosetests-3.4 --verbose tests/python/unittest + nosetests-3.4 --verbose tests/python/quantization } unittest_ubuntu_python2_gpu() { @@ -393,6 +395,30 @@ unittest_ubuntu_python3_gpu() { nosetests-3.4 --verbose tests/python/gpu } +# quantization gpu currently only runs on P3 instances +# need to separte it from unittest_ubuntu_python2_gpu() +unittest_ubuntu_python2_quantization_gpu() { + set -ex + export PYTHONPATH=./python/ + # MXNET_MKLDNN_DEBUG is buggy and produces false positives + # https://github.com/apache/incubator-mxnet/issues/10026 + #export MXNET_MKLDNN_DEBUG=1 # Ignored if not present + export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 + nosetests-2.7 --verbose tests/python/quantization_gpu +} + +# quantization gpu currently only runs on P3 instances +# need to separte it from unittest_ubuntu_python3_gpu() +unittest_ubuntu_python3_quantization_gpu() { + set -ex + export PYTHONPATH=./python/ + # MXNET_MKLDNN_DEBUG is buggy and produces false positives + # https://github.com/apache/incubator-mxnet/issues/10026 + #export MXNET_MKLDNN_DEBUG=1 # Ignored if not present + export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 + nosetests-3.4 --verbose tests/python/quantization_gpu +} + unittest_ubuntu_cpu_scala() { set -ex make scalapkg USE_BLAS=openblas diff --git a/example/quantization/README.md b/example/quantization/README.md new file mode 100644 index 000000000000..63b65574d3ac --- /dev/null +++ b/example/quantization/README.md @@ -0,0 +1,22 @@ +# Model Quantization with Calibration Examples +This folder contains examples of quantizing a FP32 model with or without calibration and using the calibrated +quantized for inference. Two pre-trained imagenet models are taken as examples for quantization. One is +[Resnet-152](http://data.mxnet.io/models/imagenet/resnet/152-layers/), and the other one is +[Inception with BatchNorm](http://data.mxnet.io/models/imagenet/inception-bn/). The calibration dataset +is the [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) for testing the pre-trained models. + +Here are the details of the four files in this folder. +- `imagenet_gen_qsym.py` This script provides an example of taking FP32 models and calibration dataset to generate +calibrated quantized models. When launched for the first time, the script would download the user-specified model, +either Resnet-152 or Inception, +and calibration dataset into `model` and `data` folders, respectively. The generated quantized models can be found in +the `model` folder. +- `imagenet_inference.py` This script is used for calculating the accuracy of FP32 models or quantized models on the +validation dataset which was downloaded for calibration in `imagenet_gen_qsym.py`. +- `launch_quantize.sh` This is a shell script that generates various quantized models for Resnet-152 and +Inception with BatchNorm with different configurations. Users can copy and paste the command from the script to +the console to run model quantization for a specific configuration. +- `launch_inference.sh` This is a shell script that calculate the accuracies of all the quantized models generated +by invoking `launch_quantize.sh`. + +**NOTE**: This example has only been tested on Linux systems. \ No newline at end of file diff --git a/example/quantization/common b/example/quantization/common new file mode 120000 index 000000000000..cafb9140ab6a --- /dev/null +++ b/example/quantization/common @@ -0,0 +1 @@ +../image-classification/common \ No newline at end of file diff --git a/example/quantization/imagenet_gen_qsym.py b/example/quantization/imagenet_gen_qsym.py new file mode 100644 index 000000000000..045ce62489ad --- /dev/null +++ b/example/quantization/imagenet_gen_qsym.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import os +import logging +from common import modelzoo +import mxnet as mx +from mxnet.contrib.quantization import * + + +def download_calib_dataset(dataset_url, calib_dataset, logger=None): + if logger is not None: + logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset)) + mx.test_utils.download(dataset_url, calib_dataset) + + +def download_model(model_name, logger=None): + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if logger is not None: + logger.info('Downloading model %s... into path %s' % (model_name, model_path)) + return modelzoo.download_model(args.model, os.path.join(dir_path, 'model')) + + +def save_symbol(fname, sym, logger=None): + if logger is not None: + logger.info('Saving symbol into file at %s' % fname) + sym.save(fname) + + +def save_params(fname, arg_params, aux_params, logger=None): + if logger is not None: + logger.info('Saving params into file at %s' % fname) + save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()} + save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) + mx.nd.save(fname, save_dict) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model') + parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'], + help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn') + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--label-name', type=str, default='softmax_label') + parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec', + help='path of the calibration dataset') + parser.add_argument('--image-shape', type=str, default='3,224,224') + parser.add_argument('--data-nthreads', type=int, default=60, + help='number of threads for data decoding') + parser.add_argument('--num-calib-batches', type=int, default=10, + help='number of batches for calibration') + parser.add_argument('--exclude-first-conv', action='store_true', default=True, + help='excluding quantizing the first conv layer since the' + ' number of channels is usually not a multiple of 4 in that layer' + ' which does not satisfy the requirement of cuDNN') + parser.add_argument('--shuffle-dataset', action='store_true', default=True, + help='shuffle the calibration dataset') + parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304, + help='shuffling chunk seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--shuffle-seed', type=int, default=48564309, + help='shuffling seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--calib-mode', type=str, default='entropy', + help='calibration mode used for generating calibration table for the quantized symbol; supports' + ' 1. none: no calibration will be used. The thresholds for quantization will be calculated' + ' on the fly. This will result in inference speed slowdown and loss of accuracy' + ' in general.' + ' 2. naive: simply take min and max values of layer outputs as thresholds for' + ' quantization. In general, the inference accuracy worsens with more examples used in' + ' calibration. It is recommended to use `entropy` mode as it produces more accurate' + ' inference results.' + ' 3. entropy: calculate KL divergence of the fp32 output and quantized output for optimal' + ' thresholds. This mode is expected to produce the best inference accuracy of all three' + ' kinds of quantized models if the calibration dataset is representative enough of the' + ' inference dataset.') + args = parser.parse_args() + + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + logger.info('shuffle_dataset=%s' % args.shuffle_dataset) + + calib_mode = args.calib_mode + logger.info('calibration mode set to %s' % calib_mode) + + # download calibration dataset + if calib_mode != 'none': + download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset) + + # download model + prefix, epoch = download_model(model_name=args.model, logger=logger) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + + # get batch size + batch_size = args.batch_size + logger.info('batch size = %d for calibration' % batch_size) + + # get number of batches for calibration + num_calib_batches = args.num_calib_batches + if calib_mode != 'none': + logger.info('number of batches = %d for calibration' % num_calib_batches) + + # get number of threads for decoding the dataset + data_nthreads = args.data_nthreads + + # get image shape + image_shape = args.image_shape + + exclude_first_conv = args.exclude_first_conv + excluded_sym_names = [] + if args.model == 'imagenet1k-resnet-152': + rgb_mean = '0,0,0' + calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 + or name.find('sc') != -1 + or name.find('fc') != -1) + if exclude_first_conv: + excluded_sym_names = ['conv0'] + elif args.model == 'imagenet1k-inception-bn': + rgb_mean = '123.68,116.779,103.939' + calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 + or name.find('fc') != -1) + if exclude_first_conv: + excluded_sym_names = ['conv_1'] + else: + raise ValueError('model %s is not supported in this script' % args.model) + + label_name = args.label_name + logger.info('label_name = %s' % label_name) + + data_shape = tuple([int(i) for i in image_shape.split(',')]) + logger.info('Input data shape = %s' % str(data_shape)) + + logger.info('rgb_mean = %s' % rgb_mean) + rgb_mean = [float(i) for i in rgb_mean.split(',')] + mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} + + if calib_mode == 'none': + logger.info('Quantizing FP32 model %s' % args.model) + qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, logger=logger) + sym_name = '%s-symbol.json' % (prefix + '-quantized') + save_symbol(sym_name, qsym, logger) + else: + logger.info('Creating ImageRecordIter for reading calibration dataset') + data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **mean_args) + + cqsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + ctx=mx.gpu(0), excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, calib_data=data, + num_calib_examples=num_calib_batches * batch_size, + calib_layer=calib_layer, logger=logger) + if calib_mode == 'entropy': + suffix = '-quantized-%dbatches-entropy' % num_calib_batches + elif calib_mode == 'naive': + suffix = '-quantized-%dbatches-naive' % num_calib_batches + else: + raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' + % calib_mode) + sym_name = '%s-symbol.json' % (prefix + suffix) + save_symbol(sym_name, cqsym, logger) + + param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) + save_params(param_name, qarg_params, aux_params, logger) diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py new file mode 100644 index 000000000000..fe3f2661c655 --- /dev/null +++ b/example/quantization/imagenet_inference.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import os +import time +import mxnet as mx +from mxnet import nd +from mxnet.contrib.quantization import * + + +def download_dataset(dataset_url, dataset_dir, logger=None): + if logger is not None: + logger.info('Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir)) + mx.test_utils.download(dataset_url, dataset_dir) + + +def load_model(symbol_file, param_file, logger=None): + cur_path = os.path.dirname(os.path.realpath(__file__)) + symbol_file_path = os.path.join(cur_path, symbol_file) + if logger is not None: + logger.info('Loading symbol from file %s' % symbol_file_path) + symbol = mx.sym.load(symbol_file_path) + + param_file_path = os.path.join(cur_path, param_file) + if logger is not None: + logger.info('Loading params from file %s' % param_file_path) + save_dict = nd.load(param_file_path) + arg_params = {} + aux_params = {} + for k, v in save_dict.items(): + tp, name = k.split(':', 1) + if tp == 'arg': + arg_params[name] = v + if tp == 'aux': + aux_params[name] = v + return symbol, arg_params, aux_params + + +def advance_data_iter(data_iter, n): + assert n >= 0 + if n == 0: + return data_iter + has_next_batch = True + while has_next_batch: + try: + data_iter.next() + n -= 1 + if n == 0: + return data_iter + except StopIteration: + has_next_batch = False + + +def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None): + metrics = [mx.metric.create('acc'), + mx.metric.create('top_k_accuracy', top_k=5)] + if not isinstance(metrics, list): + metrics = [metrics, ] + mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ]) + mod.bind(for_training=False, + data_shapes=data.provide_data, + label_shapes=data.provide_label) + mod.set_params(arg_params, aux_params) + + tic = time.time() + num = 0 + for batch in data: + mod.forward(batch, is_train=False) + for m in metrics: + mod.update_metric(m, batch.label) + num += batch_size + if max_num_examples is not None and num >= max_num_examples: + break + + speed = num / (time.time() - tic) + + if logger is not None: + logger.info('Finished inference with %d images' % num) + logger.info('Finished with %f images per second', speed) + for m in metrics: + logger.info(m.get()) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Score a model on a dataset') + parser.add_argument('--symbol-file', type=str, required=True, help='symbol file path') + parser.add_argument('--param-file', type=str, required=True, help='param file path') + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--label-name', type=str, default='softmax_label') + parser.add_argument('--dataset', type=str, required=True, help='dataset path') + parser.add_argument('--rgb-mean', type=str, default='0,0,0') + parser.add_argument('--image-shape', type=str, default='3,224,224') + parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding') + parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference') + parser.add_argument('--num-inference-batches', type=int, required=True, help='number of images used for inference') + parser.add_argument('--shuffle-dataset', action='store_true', default=True, + help='shuffle the calibration dataset') + parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304, + help='shuffling chunk seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--shuffle-seed', type=int, default=48564309, + help='shuffling seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + + args = parser.parse_args() + + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + symbol_file = args.symbol_file + param_file = args.param_file + data_nthreads = args.data_nthreads + + batch_size = args.batch_size + logger.info('batch size = %d for inference' % batch_size) + + rgb_mean = args.rgb_mean + logger.info('rgb_mean = %s' % rgb_mean) + rgb_mean = [float(i) for i in rgb_mean.split(',')] + mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} + + label_name = args.label_name + logger.info('label_name = %s' % label_name) + + image_shape = args.image_shape + data_shape = tuple([int(i) for i in image_shape.split(',')]) + logger.info('Input data shape = %s' % str(data_shape)) + + dataset = args.dataset + download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset) + logger.info('Dataset for inference: %s' % dataset) + + # creating data iterator + data = mx.io.ImageRecordIter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=True, + shuffle_chunk_seed=3982304, + seed=48564309, + **mean_args) + + # loading model + sym, arg_params, aux_params = load_model(symbol_file, param_file, logger) + + # make sure that fp32 inference works on the same images as calibrated quantized model + logger.info('Skipping the first %d batches' % args.num_skipped_batches) + data = advance_data_iter(data, args.num_skipped_batches) + + num_inference_images = args.num_inference_batches * batch_size + logger.info('Running model %s for inference' % symbol_file) + score(sym, arg_params, aux_params, data, [mx.gpu(0)], label_name, + max_num_examples=num_inference_images, logger=logger) diff --git a/example/quantization/launch_inference.sh b/example/quantization/launch_inference.sh new file mode 100755 index 000000000000..8c839ba0f611 --- /dev/null +++ b/example/quantization/launch_inference.sh @@ -0,0 +1,45 @@ +#!/bin/sh + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex + +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-symbol.json --param-file=./model/imagenet1k-resnet-152-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-10batches-naive-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-50batches-naive-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-5batches-entropy-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-10batches-entropy-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-50batches-entropy-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --rgb-mean=0,0,0 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + + +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-symbol.json --param-file=./model/imagenet1k-inception-bn-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-10batches-naive-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-50batches-naive-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-5batches-entropy-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-10batches-entropy-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec +python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-50batches-entropy-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --num-inference-batches=500 --dataset=./data/val_256_q90.rec diff --git a/example/quantization/launch_quantize.sh b/example/quantization/launch_quantize.sh new file mode 100755 index 000000000000..9aa4bee4bff1 --- /dev/null +++ b/example/quantization/launch_quantize.sh @@ -0,0 +1,41 @@ +#!/bin/sh + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex + +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-mode=none + +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-dataset=./data/val_256_q90.rec --num-calib-batches=5 --calib-mode=naive +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-dataset=./data/val_256_q90.rec --num-calib-batches=10 --calib-mode=naive +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-dataset=./data/val_256_q90.rec --num-calib-batches=50 --calib-mode=naive + +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-dataset=./data/val_256_q90.rec --num-calib-batches=5 --calib-mode=entropy +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-dataset=./data/val_256_q90.rec --num-calib-batches=10 --calib-mode=entropy +python imagenet_gen_qsym.py --model=imagenet1k-resnet-152 --calib-dataset=./data/val_256_q90.rec --num-calib-batches=50 --calib-mode=entropy + + +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-mode=none + +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-dataset=./data/val_256_q90.rec --num-calib-batches=5 --calib-mode=naive +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-dataset=./data/val_256_q90.rec --num-calib-batches=10 --calib-mode=naive +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-dataset=./data/val_256_q90.rec --num-calib-batches=50 --calib-mode=naive + +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-dataset=./data/val_256_q90.rec --num-calib-batches=5 --calib-mode=entropy +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-dataset=./data/val_256_q90.rec --num-calib-batches=10 --calib-mode=entropy +python imagenet_gen_qsym.py --model=imagenet1k-inception-bn --calib-dataset=./data/val_256_q90.rec --num-calib-batches=50 --calib-mode=entropy diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index e85afe522f0a..ede137e89b7a 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1386,8 +1386,37 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, const int **aux_type_data, int *complete); - - +/*! + * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 + * \param sym_handle symbol to be converted + * \param ret_sym_handle quantized symbol result + * \param num_excluded_symbols number of layers excluded from being quantized in the input symbol + * \param excluded_symbols array of symbols to be excluded from being quantized + * \param num_offline number of parameters that are quantized offline + * \param offline_params array of c strings representing the names of params quantized offline + */ +MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle, + const mx_uint num_excluded_symbols, + const SymbolHandle *excluded_symbols, + const mx_uint num_offline, + const char **offline_params); + +/*! + * \brief Set calibration table to node attributes in the sym + * \param sym_handle symbol whose node attributes are to be set by calibration table + * \param num_layers number of layers in the calibration table + * \param layer names stored as keys in the calibration table + * \param low_quantiles low quantiles of layers stored in the calibration table + * \param high_quantiles high quantiles of layers stored in the calibration table + * \param ret_sym_handle returned symbol + */ +MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, + const mx_uint num_layers, + const char** layer_names, + const float* low_quantiles, + const float* high_quantiles, + SymbolHandle* ret_sym_handle); //-------------------------------------------- // Part 4: Executor interface diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index fb41d3960995..820eb1aa7b02 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -239,7 +239,6 @@ using FCompute = std::function& outputs)>; /*! * \brief Resiger an NDArray compute function for simple stateless forward only operator - * * \note Register under "FComputeEx" and "FComputeEx" * Dispatched only when inferred dispatch_mode is FDispatchComputeEx */ @@ -261,6 +260,20 @@ using FInferStorageType = std::function* in_attrs, std::vector* out_attrs)>; +/*! + * \brief Register a quantized node creation function based on the attrs of the node + * \note Register under "FQuantizedOp" for non-quantized operators + */ +using FQuantizedOp = std::function; + +/*! + * \brief Register a function to determine if the output of a quantized operator + * needs to be requantized. This is usually used for the operators + * taking int8 data types while accumulating in int32, e.g. quantized_conv. + * \note Register under "FNeedRequantize" for non-quantized operators + */ +using FNeedRequantize = std::function; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py index 63cd8ce26649..fbfd3469678b 100644 --- a/python/mxnet/contrib/__init__.py +++ b/python/mxnet/contrib/__init__.py @@ -30,3 +30,5 @@ from . import text from . import onnx from . import io +from . import quantization +from . import quantization as quant diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py new file mode 100644 index 000000000000..c9c58a9c9ba4 --- /dev/null +++ b/python/mxnet/contrib/quantization.py @@ -0,0 +1,520 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Quantization module for generating quantized (INT8) models from FP32 models.""" + +from __future__ import absolute_import + +try: + from scipy import stats +except ImportError: + stats = None + +import ctypes +import logging +import os +import numpy as np +from ..base import _LIB, check_call, py_str +from ..base import c_array, c_str, mx_uint, c_str_array +from ..base import NDArrayHandle, SymbolHandle +from ..symbol import Symbol +from ..symbol import load as sym_load +from .. import ndarray +from ..ndarray import load as nd_load +from ..ndarray import NDArray +from ..io import DataIter +from ..context import cpu, Context +from ..module import Module + + +def _quantize_params(qsym, params): + """Given a quantized symbol and a dict of params that have not been quantized, + generate quantized params. Currently only supports quantizing the arg_params + with names of `weight` or `bias`, not aux_params. If `qsym` contains symbols + that are excluded from being quantized, their corresponding params will + not be quantized, but saved together with quantized params of the symbols that + have been quantized. + + Parameters + ---------- + qsym : Symbol + Quantized symbol from FP32 symbol. + params : dict of str->NDArray + """ + inputs_name = qsym.list_arguments() + quantized_params = {} + for name in inputs_name: + if name.endswith(('weight_quantize', 'bias_quantize')): + original_name = name[:-len('_quantize')] + param = params[original_name] + val, vmin, vmax = ndarray.contrib.quantize(data=param, + min_range=ndarray.min(param), + max_range=ndarray.max(param), + out_type='int8') + quantized_params[name] = val + quantized_params[name+'_min'] = vmin + quantized_params[name+'_max'] = vmax + elif name in params: + quantized_params[name] = params[name] + return quantized_params + + +def _quantize_symbol(sym, excluded_symbols=None, offline_params=None): + """Given a symbol object representing a neural network of data type FP32, + quantize it into a INT8 network. + + Parameters + ---------- + sym : Symbol + FP32 neural network symbol. + excluded_symbols : list of symbols + Nodes in the network that users do not want to replace with a symbol of INT8 data type. + offline_params : list of strs + Names of the parameters that users want to quantize offline. It's always recommended to + quantize parameters offline so that quantizing parameters during the inference can be + avoided. + """ + num_excluded_symbols = 0 + excluded_handles = [] + if excluded_symbols is not None: + assert isinstance(excluded_symbols, list) + num_excluded_symbols = len(excluded_symbols) + for s in excluded_symbols: + excluded_handles.append(s.handle) + + num_offline = 0 + offline = [] + if offline_params is not None: + num_offline = len(offline_params) + for k in offline_params: + offline.append(c_str(k)) + + out = SymbolHandle() + check_call(_LIB.MXQuantizeSymbol(sym.handle, + ctypes.byref(out), + mx_uint(num_excluded_symbols), + c_array(SymbolHandle, excluded_handles), + mx_uint(num_offline), + c_array(ctypes.c_char_p, offline))) + return Symbol(out) + + +class _LayerOutputCollector(object): + """Saves layer output NDArray in a dict with layer names as keys and lists of NDArrays as + values. The collected NDArrays will be used for calculating the optimal thresholds for + quantization using KL divergence. + """ + def __init__(self, include_layer=None, logger=None): + self.nd_dict = {} + self.include_layer = include_layer + self.logger = logger + + def collect(self, name, arr): + """Callback function for collecting layer output NDArrays.""" + name = py_str(name) + if self.include_layer is not None and not self.include_layer(name): + return + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False).copyto(cpu()) + if self.logger is not None: + self.logger.info("Collecting layer %s output of shape %s" % (name, arr.shape)) + if name in self.nd_dict: + self.nd_dict[name].append(arr) + else: + self.nd_dict[name] = [arr] + + +class _LayerOutputMinMaxCollector(object): + """Saves layer output min and max values in a dict with layer names as keys. + The collected min and max values will be directly used as thresholds for quantization. + """ + def __init__(self, include_layer=None, logger=None): + self.min_max_dict = {} + self.include_layer = include_layer + self.logger = logger + + def collect(self, name, arr): + """Callback function for collecting min and max values from an NDArray.""" + name = py_str(name) + if self.include_layer is not None and not self.include_layer(name): + return + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False) + min_range = ndarray.min(arr).asscalar() + max_range = ndarray.max(arr).asscalar() + if name in self.min_max_dict: + cur_min_max = self.min_max_dict[name] + self.min_max_dict[name] = (min(cur_min_max[0], min_range), + max(cur_min_max[1], max_range)) + else: + self.min_max_dict[name] = (min_range, max_range) + if self.logger is not None: + self.logger.info("Collecting layer %s output min_range=%f, max_range=%f" + % (name, min_range, max_range)) + + +def _calibrate_quantized_sym(qsym, th_dict): + """Given a dictionary containing the thresholds for quantizing the layers, + set the thresholds into the quantized symbol as the params of requantize operators. + """ + if th_dict is None or len(th_dict) == 0: + return qsym + num_layer_outputs = len(th_dict) + layer_output_names = [] + min_vals = [] + max_vals = [] + for k, v in th_dict.items(): + layer_output_names.append(k) + min_vals.append(v[0]) + max_vals.append(v[1]) + + calibrated_sym = SymbolHandle() + check_call(_LIB.MXSetCalibTableToQuantizedSymbol(qsym.handle, + mx_uint(num_layer_outputs), + c_str_array(layer_output_names), + c_array(ctypes.c_float, min_vals), + c_array(ctypes.c_float, max_vals), + ctypes.byref(calibrated_sym))) + return Symbol(calibrated_sym) + + +def _collect_layer_statistics(mod, data, collector, max_num_examples=None, logger=None): + if not isinstance(data, DataIter): + raise ValueError('Only supports data as a type of DataIter, while received type %s' + % str(type(data))) + mod._exec_group.execs[0].set_monitor_callback(collector.collect) + num_batches = 0 + num_examples = 0 + for batch in data: + mod.forward(data_batch=batch, is_train=False) + num_batches += 1 + num_examples += data.batch_size + if max_num_examples is not None and num_examples >= max_num_examples: + break + if logger is not None: + logger.info("Collected statistics from %d batches with batch_size=%d" + % (num_batches, data.batch_size)) + return num_examples + + +def _collect_layer_output_min_max(mod, data, include_layer=None, + max_num_examples=None, logger=None): + """Collect min and max values from layer outputs and save them in + a dictionary mapped by layer names. + """ + collector = _LayerOutputMinMaxCollector(include_layer=include_layer, logger=logger) + num_examples = _collect_layer_statistics(mod, data, collector, max_num_examples, logger) + return collector.min_max_dict, num_examples + + +def _collect_layer_outputs(mod, data, include_layer=None, max_num_examples=None, logger=None): + """Collect layer outputs and save them in a dictionary mapped by layer names.""" + collector = _LayerOutputCollector(include_layer=include_layer, logger=logger) + num_examples = _collect_layer_statistics(mod, data, collector, max_num_examples, logger) + return collector.nd_dict, num_examples + + +def _smooth_distribution(p, eps=0.0001): + """Given a discrete distribution (may have not been normalized to 1), + smooth it by replacing zeros with eps multiplied by a scaling factor and taking the + corresponding amount off the non-zero values. + Ref: http://web.engr.illinois.edu/~hanj/cs412/bk3/KL-divergence.pdf + """ + is_zeros = (p == 0).astype(np.float32) + is_nonzeros = (p != 0).astype(np.float32) + n_zeros = is_zeros.sum() + n_nonzeros = p.size - n_zeros + eps1 = eps * float(n_zeros) / float(n_nonzeros) + assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) + hist = p.astype(np.float32) + hist += eps * is_zeros + (-eps1) * is_nonzeros + assert (hist <= 0).sum() == 0 + return hist + + +# pylint: disable=line-too-long +def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): + """Given a dataset, find the optimal threshold for quantizing it. + Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + """ + if isinstance(arr, NDArray): + arr = arr.asnumpy() + elif isinstance(arr, list): + assert len(arr) != 0 + for i, nd in enumerate(arr): + if isinstance(nd, NDArray): + arr[i] = nd.asnumpy() + elif not isinstance(nd, np.ndarray): + raise TypeError('get_optimal_threshold only supports input type of NDArray,' + ' list of np.ndarrays or NDArrays, and np.ndarray,' + ' while received type=%s' % (str(type(nd)))) + arr = np.concatenate(arr) + elif not isinstance(arr, np.ndarray): + raise TypeError('get_optimal_threshold only supports input type of NDArray,' + ' list of NDArrays and np.ndarray,' + ' while received type=%s' % (str(type(arr)))) + min_val = np.min(arr) + max_val = np.max(arr) + th = max(abs(min_val), abs(max_val)) + + hist, hist_edeges = np.histogram(arr, bins=num_bins, range=(-th, th)) + zero_bin_idx = num_bins // 2 + num_half_quantized_bins = num_quantized_bins // 2 + assert np.allclose(hist_edeges[zero_bin_idx] + hist_edeges[zero_bin_idx + 1], + 0, rtol=1e-5, atol=1e-7) + + thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) + divergence = np.zeros_like(thresholds) + quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32) + # i means the number of bins on half axis excluding the zero bin + for i in range(num_quantized_bins // 2, + num_bins // 2 + 1): + p_bin_idx_start = zero_bin_idx - i + p_bin_idx_stop = zero_bin_idx + i + 1 + thresholds[i - num_half_quantized_bins] = hist_edeges[p_bin_idx_stop] + # sliced_nd_hist is used to generate candidate distribution q + sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop] + + # generate reference distribution p + p = sliced_nd_hist.copy() + assert p.size % 2 == 1 + assert p.size >= num_quantized_bins + # put left outlier count in p[0] + left_outlier_count = np.sum(hist[0:p_bin_idx_start]) + p[0] += left_outlier_count + # put right outlier count in p[-1] + right_outlier_count = np.sum(hist[p_bin_idx_stop:]) + p[-1] += right_outlier_count + # is_nonzeros[k] indicates whether hist[k] is nonzero + is_nonzeros = (sliced_nd_hist != 0).astype(np.int32) + + # calculate how many bins should be merged to generate quantized distribution q + num_merged_bins = p.size // num_quantized_bins + # merge hist into num_quantized_bins bins + for j in range(num_quantized_bins): + start = j * num_merged_bins + stop = start + num_merged_bins + quantized_bins[j] = sliced_nd_hist[start:stop].sum() + quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() + # expand quantized_bins into p.size bins + q = np.zeros(p.size, dtype=np.float32) + for j in range(num_quantized_bins): + start = j * num_merged_bins + if j == num_quantized_bins - 1: + stop = -1 + else: + stop = start + num_merged_bins + norm = is_nonzeros[start:stop].sum() + if norm != 0: + q[start:stop] = float(quantized_bins[j]) / float(norm) + q[sliced_nd_hist == 0] = 0 + p = _smooth_distribution(p) + q = _smooth_distribution(q) + divergence[i - num_half_quantized_bins] = stats.entropy(p, q) + quantized_bins[:] = 0 + + min_divergence_idx = np.argmin(divergence) + min_divergence = divergence[min_divergence_idx] + opt_th = thresholds[min_divergence_idx] + return min_val, max_val, min_divergence, opt_th +# pylint: enable=line-too-long + + +def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logger=None): + """Given a ndarray dict, find the optimal threshold for quantizing each value of the key.""" + if stats is None: + raise ImportError('scipy.stats is required for running entropy mode of calculating' + ' the optimal thresholds for quantizing FP32 ndarrays into int8.' + ' Please check if the scipy python bindings are installed.') + assert isinstance(nd_dict, dict) + if logger is not None: + logger.info('Calculating optimal thresholds for quantization using KL divergence' + ' with num_bins=%d and num_quantized_bins=%d' % (num_bins, num_quantized_bins)) + th_dict = {} + # copy nd_dict keys since the keys() only returns a view in python3 + layer_names = list(nd_dict.keys()) + for name in layer_names: + assert name in nd_dict + min_val, max_val, min_divergence, opt_th =\ + _get_optimal_threshold(nd_dict[name], num_bins=num_bins, + num_quantized_bins=num_quantized_bins) + del nd_dict[name] # release the memory of ndarray + th_dict[name] = (-opt_th, opt_th) + if logger is not None: + logger.info('layer=%s, min_val=%f, max_val=%f, min_divergence=%f, optimal_threshold=%f' + % (name, min_val, max_val, min_divergence, opt_th)) + return th_dict + + +def _load_sym(sym, logger=logging): + """Given a str as a path the symbol .json file or a symbol, returns a Symbol object.""" + if isinstance(sym, str): # sym is a symbol file path + cur_path = os.path.dirname(os.path.realpath(__file__)) + symbol_file_path = os.path.join(cur_path, sym) + logger.info('Loading symbol from file %s' % symbol_file_path) + return sym_load(symbol_file_path) + elif isinstance(sym, Symbol): + return sym + else: + raise ValueError('_load_sym only accepts Symbol or path to the symbol file,' + ' while received type %s' % str(type(sym))) + + +def _load_params(params, logger=logging): + """Given a str as a path to the .params file or a pair of params, + returns two dictionaries representing arg_params and aux_params. + """ + if isinstance(params, str): + cur_path = os.path.dirname(os.path.realpath(__file__)) + param_file_path = os.path.join(cur_path, params) + logger.info('Loading params from file %s' % param_file_path) + save_dict = nd_load(param_file_path) + arg_params = {} + aux_params = {} + for k, v in save_dict.items(): + tp, name = k.split(':', 1) + if tp == 'arg': + arg_params[name] = v + if tp == 'aux': + aux_params[name] = v + return arg_params, aux_params + elif isinstance(params, (tuple, list)) and len(params) == 2: + return params[0], params[1] + else: + raise ValueError('Unsupported params provided. Must be either a path to the param file or' + ' a pair of dictionaries representing arg_params and aux_params') + + +def quantize_model(sym, arg_params, aux_params, + data_names=('data',), label_names=('softmax_label',), + ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', + calib_data=None, num_calib_examples=None, calib_layer=None, logger=logging): + """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + The quantization implementation adopts the TensorFlow's approach: + https://www.tensorflow.org/performance/quantization. + The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + and adapts the method to MXNet. + + Parameters + ---------- + sym : str or Symbol + Defines the structure of a neural network for FP32 data types. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + data_names : a list of strs + Data names required for creating a Module object to run forward propagation on the + calibration dataset. + label_names : a list of strs + Label names required for creating a Module object to run forward propagation on the + calibration dataset. + ctx : Context + Defines the device that users want to run forward propagation on the calibration + dataset for collecting layer output statistics. Currently, only supports single context. + excluded_sym_names : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + calib_data : DataIter + A data iterator initialized by the calibration dataset. + num_calib_examples : int or None + The maximum number of examples that user would like to use for calibration. If not provided, + the whole calibration dataset will be used. + calib_layer : function + Given a layer's output name in string, return True or False for deciding whether to + calibrate this layer. If yes, the statistics of the layer's output will be collected; + otherwise, no information of the layer's output will be collected. If not provided, + all the layers' outputs that need requantization will be collected. + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + tuple + A tuple of quantized symbol, quantized arg_params, and aux_params. + ------- + """ + if excluded_sym_names is None: + excluded_sym_names = [] + if not isinstance(excluded_sym_names, list): + raise ValueError('excluded_sym_names must be a list of strings representing' + ' the names of the symbols that will not be quantized,' + ' while received type %s' % str(type(excluded_sym_names))) + excluded_syms = [] + if excluded_sym_names is not None: + for sym_name in excluded_sym_names: + nodes = sym.get_internals() + idx = nodes.list_outputs().index(sym_name + '_output') + excluded_syms.append(nodes[idx]) + logger.info('Quantizing symbol') + qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms, + offline_params=list(arg_params.keys())) + + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params) + + if calib_mode is not None and calib_mode != 'none': + if not isinstance(ctx, Context): + raise ValueError('currently only supports single ctx, while received %s' % str(ctx)) + if calib_data is None: + raise ValueError('calib_data must be provided when calib_mode=%s' % calib_mode) + if not isinstance(calib_data, DataIter): + raise ValueError('calib_data must be of DataIter type when calib_mode=%s,' + ' while received type %s' % (calib_mode, str(type(calib_data)))) + if calib_layer is None: + calib_layer = lambda name: name.endswith('_output') + + mod = Module(symbol=sym, data_names=data_names, label_names=label_names, context=ctx) + if len(calib_data.provide_label) > 0: + mod.bind(for_training=False, data_shapes=calib_data.provide_data, + label_shapes=calib_data.provide_label) + else: + mod.bind(for_training=False, data_shapes=calib_data.provide_data) + mod.set_params(arg_params, aux_params) + if calib_mode == 'entropy': + nd_dict, num_examples = _collect_layer_outputs(mod, calib_data, + include_layer=calib_layer, + max_num_examples=num_calib_examples, + logger=logger) + logger.info('Collected layer outputs from FP32 model using %d examples' % num_examples) + logger.info('Calculating optimal thresholds for quantization') + th_dict = _get_optimal_thresholds(nd_dict, logger=logger) + elif calib_mode == 'naive': + th_dict, num_examples = _collect_layer_output_min_max( + mod, calib_data, include_layer=calib_layer, max_num_examples=num_calib_examples, + logger=logger) + logger.info('Collected layer output min/max values from FP32 model using %d examples' + % num_examples) + else: + raise ValueError('unknown calibration mode %s received,' + ' expected `none`, `naive`, or `entropy`' % calib_mode) + logger.info('Calibrating quantized symbol') + qsym = _calibrate_quantized_sym(qsym, th_dict) + + return qsym, qarg_params, aux_params diff --git a/python/mxnet/model.py b/python/mxnet/model.py index c66b17cab34d..26e885a1cd8d 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -28,7 +28,7 @@ import numpy as np from . import io -from . import nd +from . import ndarray as nd from . import symbol as sym from . import optimizer as opt from . import metric @@ -592,15 +592,17 @@ def __setstate__(self, state): def _init_predictor(self, input_shapes, type_dict=None): """Initialize the predictor module for running prediction.""" + shapes = {name: self.arg_params[name].shape for name in self.arg_params} + shapes.update(dict(input_shapes)) if self._pred_exec is not None: - arg_shapes, _, _ = self.symbol.infer_shape(**dict(input_shapes)) + arg_shapes, _, _ = self.symbol.infer_shape(**shapes) assert arg_shapes is not None, "Incomplete input shapes" pred_shapes = [x.shape for x in self._pred_exec.arg_arrays] if arg_shapes == pred_shapes: return # for now only use the first device pred_exec = self.symbol.simple_bind( - self.ctx[0], grad_req='null', type_dict=type_dict, **dict(input_shapes)) + self.ctx[0], grad_req='null', type_dict=type_dict, **shapes) pred_exec.copy_params_from(self.arg_params, self.aux_params) _check_arguments(self.symbol) diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 3668af060069..4666b6adf0c3 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -571,3 +571,54 @@ int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHand LOG(FATAL) << "not implemented"; API_END(); } + +int MXQuantizeSymbol(SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle, + const mx_uint num_excluded_symbols, + const SymbolHandle *excluded_symbols, + const mx_uint num_offline, + const char **offline_params) { + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(sym_handle); + nnvm::Graph g = Symbol2Graph(*sym); + std::unordered_set excluded_nodes; + for (size_t i = 0; i < num_excluded_symbols; ++i) { + nnvm::Symbol* sym = static_cast(excluded_symbols[i]); + for (const auto& e : sym->outputs) { + excluded_nodes.emplace(e.node); + } + } + g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_nodes)); + std::unordered_set offline; + for (size_t i = 0; i < num_offline; ++i) { + offline.emplace(offline_params[i]); + } + g.attrs["offline_params"] = std::make_shared(std::move(offline)); + g = ApplyPass(std::move(g), "QuantizeGraph"); + s->outputs = g.outputs; + *ret_sym_handle = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, + const mx_uint num_layers, + const char** layer_names, + const float* min_ranges, + const float* max_ranges, + SymbolHandle* ret_qsym_handle) { + nnvm::Symbol* s = new nnvm::Symbol(); + API_BEGIN(); + nnvm::Symbol* sym = static_cast(qsym_handle); + nnvm::Graph g = Symbol2Graph(*sym); + const std::string prefix = "quantized_"; + std::unordered_map> calib_table; + for (size_t i = 0; i < num_layers; ++i) { + calib_table.emplace(prefix+layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); + } + g.attrs["calib_table"] = std::make_shared(std::move(calib_table)); + g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); + s->outputs = g.outputs; + *ret_qsym_handle = s; + API_END_HANDLE_ERROR(delete s); +} diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h index 2682b94b4fae..f06a4e4aabe9 100644 --- a/src/io/inst_vector.h +++ b/src/io/inst_vector.h @@ -29,6 +29,7 @@ #include #include +#include #include #include #include diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index c98a010774d7..5632d73c2611 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 951063fb4b2f..79f2e800ee14 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -468,6 +468,10 @@ There are other options to tune the performance. else return std::vector{"data", "weight", "bias"}; }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) .set_attr("FInferShape", ConvolutionShape) .set_attr("FInferType", ConvolutionType) .set_attr("FInferStorageType", ConvStorageType) diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 75d594ffd91d..475b63625166 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -267,6 +267,10 @@ This could be used for model inference with `row_sparse` weights trained with `S return std::vector{"data", "weight"}; } }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) #if MXNET_USE_MKLDNN == 1 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index ff429df9d7b4..18f088f758e4 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -103,7 +103,8 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs, size_t workspace_size = 0; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - workspace_size = broadcast::ReduceWorkspaceSize(s, mean_data, req[0], in_data); + workspace_size = + broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0], in_data.shape_); }); }); workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -202,16 +203,14 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { reduce_workspace_size = std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize( - s, ograd.reshape(red_src_shape), kAddTo, - mean.reshape(red_dst_shape))); + broadcast::ReduceWorkspaceSize(s, red_src_shape, + kAddTo, red_dst_shape)); }); BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { reduce_workspace_size = std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize( - s, ograd.reshape(red_exclude_src_shape), kAddTo, - gamma.reshape(red_exclude_dst_shape))); + broadcast::ReduceWorkspaceSize(s, red_exclude_src_shape, kAddTo, + red_exclude_dst_shape)); }); }); workspace = ctx.requested[0].get_space_typed( diff --git a/src/operator/contrib/dequantize-inl.h b/src/operator/quantization/dequantize-inl.h similarity index 52% rename from src/operator/contrib/dequantize-inl.h rename to src/operator/quantization/dequantize-inl.h index 8f24a8fd7b5c..799e13665664 100644 --- a/src/operator/contrib/dequantize-inl.h +++ b/src/operator/quantization/dequantize-inl.h @@ -22,8 +22,8 @@ * \file dequantize-inl.h * \brief Implementation of dequantize operation */ -#ifndef MXNET_OPERATOR_CONTRIB_DEQUANTIZE_INL_H_ -#define MXNET_OPERATOR_CONTRIB_DEQUANTIZE_INL_H_ +#ifndef MXNET_OPERATOR_QUANTIZATION_DEQUANTIZE_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_DEQUANTIZE_INL_H_ #include #include @@ -31,6 +31,7 @@ #include "../elemwise_op_common.h" #include "../mshadow_op.h" #include "../mxnet_op.h" +#include "./quantization_utils.h" namespace mxnet { namespace op { @@ -40,18 +41,30 @@ struct DequantizeParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(DequantizeParam) { DMLC_DECLARE_FIELD(out_type) .add_enum("float32", mshadow::kFloat32) + .set_default(mshadow::kFloat32) .describe("Output data type."); } }; -struct dequantize { +// dequantize unsigned int8 to float32 +struct dequantize_unsigned { template MSHADOW_XINLINE static void Map(int i, DstDType *out, const SrcDType *in, - float *imin_range, float *imax_range, - double imin_limit, double imax_limit, - float half_range) { - float scale = (*imax_range - *imin_range) / (imax_limit - imin_limit); - out[i] = static_cast((in[i] + half_range) * scale + *imin_range); + const float *imin_range, const float *imax_range, + const float imin_limit, const float imax_limit) { + const float scale = (*imax_range - *imin_range) / (imax_limit - imin_limit); + out[i] = static_cast(in[i] * scale + *imin_range); + } +}; + +// keep zero-center +struct dequantize_zero_centered { + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, const SrcDType *in, + const float *imin_range, const float *imax_range, + const float quantized_range) { + const float real_range = MaxAbs(*imax_range, *imin_range); + out[i] = in[i] * (real_range / quantized_range); } }; @@ -63,20 +76,20 @@ void DequantizeCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; Stream *s = ctx.get_stream(); - - // for now, only supports dequantize from float to uint8 - typedef float DstDType; - typedef uint8_t SrcDType; - double min_limit = static_cast(std::numeric_limits::min()); - double max_limit = static_cast(std::numeric_limits::max()); - float half_range = !std::is_signed::value - ? 0.0f - : (max_limit - min_limit + 1) / 2.0; - - Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), - min_limit, max_limit, half_range); + if (inputs[0].type_flag_ == mshadow::kUint8) { + Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), + MinValue(), MaxValue()); + } else if (inputs[0].type_flag_ == mshadow::kInt8) { + Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), + MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "dequantize op only supports input type int8 or uint8"; + } } inline bool DequantizeShape(const nnvm::NodeAttrs& attrs, @@ -85,30 +98,27 @@ inline bool DequantizeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 1U); - CHECK(!shape_is_none(in_attrs->at(0))); for (size_t i = 1; i < 3; ++i) { - CHECK(shape_is_scalar(in_attrs->at(i))) << in_attrs->at(i); + SHAPE_ASSIGN_CHECK(*in_attrs, i, TShape({1})); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - return true; + return !shape_is_none(out_attrs->at(0)); } inline bool DequantizeType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 1U); - CHECK_EQ((*in_attrs)[0], mshadow::kUint8) - << "`dequantize` only supports uint8 input for now"; - CHECK_EQ((*in_attrs)[1], mshadow::kFloat32) - << "the second input of `dequantize` should be a tensor with type of float"; - CHECK_EQ((*in_attrs)[2], mshadow::kFloat32) - << "the third input of `dequantize` should be a tensor with type of float"; + CHECK(in_attrs->at(0) == mshadow::kUint8 || in_attrs->at(0) == mshadow::kInt8) + << "the input data type of dequantize op must be provided, either uint8 or int8"; + TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32); TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); return (*in_attrs)[0] != -1; } } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_CONTRIB_DEQUANTIZE_INL_H_ +#endif // MXNET_OPERATOR_QUANTIZATION_DEQUANTIZE_INL_H_ diff --git a/src/operator/contrib/dequantize.cc b/src/operator/quantization/dequantize.cc similarity index 71% rename from src/operator/contrib/dequantize.cc rename to src/operator/quantization/dequantize.cc index 7814a157719c..92b808dd460d 100644 --- a/src/operator/contrib/dequantize.cc +++ b/src/operator/quantization/dequantize.cc @@ -30,14 +30,20 @@ DMLC_REGISTER_PARAMETER(DequantizeParam); NNVM_REGISTER_OP(_contrib_dequantize) .describe(R"code(Dequantize the input tensor into a float tensor. -[min_range, max_range] are scalar floats that spcify the range for +min_range and max_range are scalar floats that specify the range for the output data. -Each value of the tensor will undergo the following: +When input data type is `uint8`, the output is calculated using the following equation: -`out[i] = min_range + (in[i] * (max_range - min_range) / range(INPUT_TYPE))` +`out[i] = in[i] * (max_range - min_range) / 255.0`, -here `range(T) = numeric_limits::max() - numeric_limits::min()` +When input data type is `int8`, the output is calculate using the following equation +by keep zero centered for the quantized value: + +`out[i] = in[i] * MaxAbs(min_range, max_range) / 127.0`, + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training. )code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(3) @@ -45,12 +51,11 @@ here `range(T) = numeric_limits::max() - numeric_limits::min()` .set_attr("FInferShape", DequantizeShape) .set_attr("FInferType", DequantizeType) .set_attr("FCompute", DequantizeCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_dequantize"}) -.add_argument("input", "NDArray-or-Symbol", "A ndarray/symbol of type `uint8`") +.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `uint8`") .add_argument("min_range", "NDArray-or-Symbol", "The minimum scalar value " - "possibly produced for the input") + "possibly produced for the input in float32") .add_argument("max_range", "NDArray-or-Symbol", "The maximum scalar value " - "possibly produced for the input") + "possibly produced for the input in float32") .add_arguments(DequantizeParam::__FIELDS__()); } // namespace op diff --git a/src/operator/contrib/dequantize.cu b/src/operator/quantization/dequantize.cu similarity index 100% rename from src/operator/contrib/dequantize.cu rename to src/operator/quantization/dequantize.cu diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h new file mode 100644 index 000000000000..5b096ac0057a --- /dev/null +++ b/src/operator/quantization/quantization_utils.h @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantization_utils-inl.h + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZATION_UTILS_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZATION_UTILS_H_ + +#include +#include +#include "../mxnet_op.h" + +namespace mxnet { +namespace op { + + +template +MSHADOW_XINLINE int Sign(T val) { + return (val > T(0)) - (val < T(0)); +} + +template +MSHADOW_XINLINE T Abs(T a) { +#ifdef __CUDACC__ + return ::abs(a); +#else + return std::abs(a); +#endif +} + +template +MSHADOW_XINLINE T Max(T a, T b) { +#ifdef __CUDACC__ + return ::max(a, b); +#else + return std::max(a, b); +#endif +} + +template +MSHADOW_XINLINE T Min(T a, T b) { +#ifdef __CUDACC__ + return ::min(a, b); +#else + return std::min(a, b); +#endif +} + +template +MSHADOW_XINLINE float MaxAbs(T a, T b) { + return Max(Abs(static_cast(a)), Abs(static_cast(b))); +} + +template +MSHADOW_XINLINE float MinAbs(T a, T b) { + return Min(Abs(static_cast(a)), Abs(static_cast(b))); +} + +template +MSHADOW_XINLINE T FloatToQuantized(float input, float min_range, float max_range) { + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + float real_range = MaxAbs(min_range, max_range); + float quantized_range = MinAbs(MaxValue(), MinValue()); + float scale = quantized_range / real_range; + return Sign(input) * Min(Abs(input) * scale + 0.5f, quantized_range); +} + +template +MSHADOW_XINLINE float QuantizedToFloat(T input, float min_range, float max_range) { + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + float quantized_range = MinAbs(MinValue(), MaxValue()); + float real_range = MaxAbs(min_range, max_range); + float scale = real_range / quantized_range; + return input * scale; +} + +struct QuantizedToFloatStruct { + template + MSHADOW_XINLINE static void Map(int i, float *output, const T *input, + const float *range_min, const float *range_max) { + output[i] = QuantizedToFloat(input[i], *range_min, *range_max); + } +}; + +template +MSHADOW_XINLINE T2 RequantizeInNewRange(T1 input, float min_input, float max_input, + float min_new, float max_new) { + const float input_float = QuantizedToFloat(input, min_input, max_input); + return FloatToQuantized(input_float, min_new, max_new); +} + +template +MSHADOW_XINLINE void RequantizeManyInNewRange(size_t count, T2* output, const T1 *input, + float input_min, float input_max, + float actual_min, float actual_max) { + for (size_t i = 0; i < count; ++i) { + const float input_float = + QuantizedToFloat(input[i], input_min, input_max); + output[i] = FloatToQuantized(input_float, actual_min, actual_max); + } +} + +/*! + * \brief Get the scaling factor for converting type T to float. + */ +template +MSHADOW_XINLINE float FloatForOneQuantizedLevel(float range_min, float range_max) { + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + const int64_t highest = static_cast(MaxValue()); + const int64_t lowest = static_cast(MinValue()); + const float float_for_one_quantized_level = + (range_max - range_min) / (highest - lowest); + return float_for_one_quantized_level; +} + +template +MSHADOW_XINLINE void QuantizationRangeForMultiplication(float min_a, float max_a, + float min_b, float max_b, + float* min_c, float* max_c) { + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + const float a_float_for_one_quant_level = + FloatForOneQuantizedLevel(min_a, max_a); + const float b_float_for_one_quant_level = + FloatForOneQuantizedLevel(min_b, max_b); + + const int64_t c_highest = + static_cast(MaxValue()); + const int64_t c_lowest = + static_cast(MinValue()); + const float c_float_for_one_quant_level = + a_float_for_one_quant_level * b_float_for_one_quant_level; + + *min_c = c_float_for_one_quant_level * c_lowest; + *max_c = c_float_for_one_quant_level * c_highest; +} + +struct QuantizationRangeForMultiplicationStruct { + MSHADOW_XINLINE static void Map(int i, + float *min_c, + float *max_c, + const float *min_a, + const float *max_a, + const float *min_b, + const float *max_b) { + QuantizationRangeForMultiplication( + min_a[i], max_a[i], min_b[i], max_b[i], min_c, max_c); + } +}; + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZATION_UTILS_H_ diff --git a/src/operator/contrib/quantize-inl.h b/src/operator/quantization/quantize-inl.h similarity index 52% rename from src/operator/contrib/quantize-inl.h rename to src/operator/quantization/quantize-inl.h index 4d55b1b5c6d0..8b7a11cc5a89 100644 --- a/src/operator/contrib/quantize-inl.h +++ b/src/operator/quantization/quantize-inl.h @@ -22,8 +22,8 @@ * \file quantize-inl.h * \brief implementation of quantize operation */ -#ifndef MXNET_OPERATOR_CONTRIB_QUANTIZE_INL_H_ -#define MXNET_OPERATOR_CONTRIB_QUANTIZE_INL_H_ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_ #include #include @@ -31,6 +31,7 @@ #include "../elemwise_op_common.h" #include "../mshadow_op.h" #include "../mxnet_op.h" +#include "./quantization_utils.h" namespace mxnet { namespace op { @@ -39,25 +40,47 @@ struct QuantizeParam : public dmlc::Parameter { int out_type; DMLC_DECLARE_PARAMETER(QuantizeParam) { DMLC_DECLARE_FIELD(out_type) + .add_enum("int8", mshadow::kInt8) .add_enum("uint8", mshadow::kUint8) .set_default(mshadow::kUint8) .describe("Output data type."); } }; -struct quantize { +// quantize float to uint8_t +struct quantize_unsigned { template MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range, const SrcDType *in, const float *imin_range, const float *imax_range, - double min_limit, double max_limit) { - float scale = (max_limit - min_limit) / (*imax_range - *imin_range); + const double min_limit, const double max_limit) { + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + const float scale = (max_limit - min_limit) / (*imax_range - *imin_range); out[i] = static_cast((in[i] - *imin_range) * scale + 0.5); *omin_range = *imin_range; *omax_range = *imax_range; } }; + +// keep zero-center +struct quantize_zero_centered { + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, + float *omax_range, const SrcDType *in, + const float *imin_range, const float *imax_range, + const float quantized_range) { + float real_range = MaxAbs(*imin_range, *imax_range); + float scale = quantized_range / real_range; + SrcDType x = in[i]; + out[i] = static_cast( + Sign(x) * Min(Abs(x) * scale + 0.5f, quantized_range)); + *omin_range = -real_range; + *omax_range = real_range; + } +}; + template void QuantizeCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -66,16 +89,24 @@ void QuantizeCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; Stream *s = ctx.get_stream(); - // for now, only supports quantize from uint8 to float - // TODO(ziheng) consider add MSHADOW_INTEGER_TYPE_SWITCH - typedef uint8_t DstDType; - typedef float SrcDType; - Kernel::Launch(s, outputs[0].Size(), - outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), - inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), - std::numeric_limits::min(), std::numeric_limits::max()); + const QuantizeParam& param = nnvm::get(attrs.parsed); + if (param.out_type == mshadow::kUint8) { + Kernel::Launch(s, outputs[0].Size(), + outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), + MinValue(), MaxValue()); + } else if (param.out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch(s, outputs[0].Size(), + outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), + MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } } inline bool QuantizeShape(const nnvm::NodeAttrs& attrs, @@ -84,15 +115,14 @@ inline bool QuantizeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 3U); - CHECK(!shape_is_none(in_attrs->at(0))); for (size_t i = 1; i < 3; ++i) { - CHECK(shape_is_scalar(in_attrs->at(i))); + SHAPE_ASSIGN_CHECK(*in_attrs, i, TShape({1})); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1}); SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1}); - return true; + return !shape_is_none(out_attrs->at(0)); } inline bool QuantizeType(const nnvm::NodeAttrs& attrs, @@ -100,13 +130,17 @@ inline bool QuantizeType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 3U); - CHECK_EQ((*in_attrs)[0], mshadow::kFloat32) - << "`quantize` only supports float32 input for now"; - CHECK_EQ((*in_attrs)[1], mshadow::kFloat32) - << "the second input of `quantize` should be a tensor with type of float"; - CHECK_EQ((*in_attrs)[2], mshadow::kFloat32) - << "the third input of `quantize` should be a tensor with type of float"; - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); + const QuantizeParam& param = nnvm::get(attrs.parsed); + TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32); + if (param.out_type == mshadow::kUint8) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); + } else if (param.out_type == mshadow::kInt8) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); return (*in_attrs)[0] != -1; @@ -114,4 +148,4 @@ inline bool QuantizeType(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_CONTRIB_QUANTIZE_INL_H_ +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_INL_H_ diff --git a/src/operator/contrib/quantize.cc b/src/operator/quantization/quantize.cc similarity index 65% rename from src/operator/contrib/quantize.cc rename to src/operator/quantization/quantize.cc index 43d60d1dd83d..32eb952fa5d7 100644 --- a/src/operator/contrib/quantize.cc +++ b/src/operator/quantization/quantize.cc @@ -32,21 +32,37 @@ NNVM_REGISTER_OP(_contrib_quantize) .describe(R"code(Quantize a input tensor from float to `out_type`, with user-specified `min_range` and `max_range`. -[min_range, max_range] are scalar floats that spcify the range for -the input data. Each value of the tensor will undergo the following: +min_range and max_range are scalar floats that specify the range for +the input data. -`out[i] = (in[i] - min_range) * range(OUTPUT_TYPE) / (max_range - min_range)` +When out_type is `uint8`, the output is calculated using the following equation: -here `range(T) = numeric_limits::max() - numeric_limits::min()` -)code" ADD_FILELINE) +`out[i] = (in[i] - min_range) * range(OUTPUT_TYPE) / (max_range - min_range) + 0.5`, + +where `range(T) = numeric_limits::max() - numeric_limits::min()`. + +When out_type is `int8`, the output is calculate using the following equation +by keep zero centered for the quantized value: + +`out[i] = sign(in[i]) * min(abs(in[i] * scale + 0.5f, quantized_range)`, + +where +`quantized_range = MinAbs(max(int8), min(int8))` and +`scale = quantized_range / MaxAbs(min_range, max_range).` + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(3) .set_num_outputs(3) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "min_range", "max_range"}; + }) .set_attr("FInferShape", QuantizeShape) .set_attr("FInferType", QuantizeType) .set_attr("FCompute", QuantizeCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_quantize"}) -.add_argument("input", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") +.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") .add_argument("min_range", "NDArray-or-Symbol", "The minimum scalar value " "possibly produced for the input") .add_argument("max_range", "NDArray-or-Symbol", "The maximum scalar value " diff --git a/src/operator/contrib/quantize.cu b/src/operator/quantization/quantize.cu similarity index 100% rename from src/operator/contrib/quantize.cu rename to src/operator/quantization/quantize.cu diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc new file mode 100644 index 000000000000..5ec745ccdf31 --- /dev/null +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2016 by Contributors + * \file quantization.cc + * \brief + */ +#include +#include +#include +#include + +namespace mxnet { +namespace op { + +using nnvm::Symbol; +using nnvm::Node; +using nnvm::NodePtr; +using nnvm::NodeEntry; +using nnvm::Graph; + +NodePtr CreateNode(std::string op_name, std::string node_name) { + NodePtr node = Node::Create(); + node->attrs.name = node_name; + if (op_name == "nullptr") { + node->attrs.op = nullptr; + // ugly workaround because VariableParam is not exposed + node->attrs.parsed = + nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed; + } else { + node->attrs.op = Op::Get(op_name); + } + return node; +} + +/*! + * \brief Insert a node named with node_name holding the op of op_name + * before the node current and after the node previous. + */ +NodePtr InsertNode(std::string op_name, + std::string node_name, NodePtr current, NodeEntry previous) { + NodePtr node = CreateNode(op_name, node_name); + node->inputs.emplace_back(previous); + current->inputs.emplace_back(NodeEntry{node, 0, 0}); + return node; +} + +std::vector OfflineParams(std::vector&& outputs, + std::unordered_set&& offline_params) { + std::string node_suffixs[3] = {"", "_min", "_max"}; + std::unordered_map mirror_map; + nnvm::NodeEntryMap entry_var; + auto need_offline = [&](NodePtr n) { + return n->op() && + (n->op()->name == "_contrib_quantize") && + n->inputs[0].node->is_variable() && + offline_params.count(n->inputs[0].node->attrs.name); + }; + DFSVisit(outputs, [&](const NodePtr& node) { + for (NodeEntry& e : node->inputs) { + if (need_offline(e.node)) { + std::string node_name = e.node->attrs.name; + if (!entry_var.count(e)) { + entry_var[e] = CreateNode("nullptr", node_name + node_suffixs[e.index]); + } + e.node = entry_var[e]; + e.index = 0; + e.version = 0; + } + } + }); + return outputs; +} + +inline bool NeedQuantize(NodePtr node, const std::unordered_set excluded_nodes) { + static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); + return quantized_op_map.count(node->op()) && !excluded_nodes.count(node); +} + +Graph QuantizeGraph(Graph &&src) { + static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); + static auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); + auto offline_params = src.GetAttr>("offline_params"); + auto excluded_nodes = src.GetAttr>("excluded_nodes"); + + // mirror_map stores the mapping from the currently visited graph to the newly created quantized + // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key + // node. The existing key's value may be updated with the newly created quantize/dequantize op. + std::unordered_map mirror_map; + DFSVisit(src.outputs, [&](const NodePtr& node) { + NodePtr new_node = Node::Create(); + // If the currently visited node needs quantization, insert a quantize op node before the + // current node and replace the current node with the quantized version in the new graph. + if (NeedQuantize(node, excluded_nodes)) { + auto fquantized_op = quantized_op_map[node->op()]; + // If the currently visited node's op registered the FQuantizedOp property, new_node is a + // quantizated version of a that op, such as quantized_conv2d. + new_node = fquantized_op(node->attrs); + + // add data into quantized op input + for (const auto& e : node->inputs) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; + // If the NodeEntry e's node does not need quantization, and (the mirror_node is a variable, + // or the mirror_node's op is not a quantize op), create quantize op, min op, and max op + // taking mirror_entry as input to generate a quantized NDArray. Save the mapping between + // e's source node and the newly created quantize op so that the quantize op can be + // reused next time when the same entry is visited again. + if (!NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() == nullptr || + mirror_node->op()->name != "_contrib_quantize")) { + NodePtr quantize_node = InsertNode("_contrib_quantize", + e.node->attrs.name + "_quantize", new_node, mirror_entry); + quantize_node->attrs.dict["out_type"] = "int8"; + quantize_node->op()->attr_parser(&(quantize_node->attrs)); + + NodePtr min_node = InsertNode("min", + e.node->attrs.name + "_min", quantize_node, mirror_entry); + min_node->op()->attr_parser(&(min_node->attrs)); + + NodePtr max_node = InsertNode("max", + e.node->attrs.name + "_max", quantize_node, mirror_entry); + max_node->op()->attr_parser(&(max_node->attrs)); + + mirror_map[e.node.get()] = std::move(quantize_node); + } else { + // If the entry e's node needs quantization, or mirror_entry is from a quantize op, + // simply add mirror_entry to the input of the new_node. + new_node->inputs.emplace_back(mirror_entry); + } + // the input should be `quantize` or quantized version op now + } + + // add min and max into quantized op input assume order of quantized op inputs is: + // data1, data2, ..., min1, max1, min2, max2, ... + for (const auto& e : node->inputs) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; + // for quantize node + uint32_t min_index = 1; + uint32_t max_index = 2; + if (quantized_op_map.count(e.node->op())) { + size_t num_outputs = e.node->num_outputs(); + min_index = num_outputs + 2 * e.index; + max_index = num_outputs + 2 * e.index + 1; + } else { + CHECK(mirror_node->op()->name == "_contrib_quantize") + << "The input is not quantize or quantized_op"; + } + new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + } + + // If the new_node op registered attr FNeedRequantize, insert requantize node after it. + // Here it's assumed that the quantized_op node only produces three outputs: + // out_data, min_range, and max_range. + if (need_requantize_map.count(new_node->op()) > 0 + && need_requantize_map[new_node->op()](new_node->attrs)) { + NodePtr requantize_node = Node::Create(); + requantize_node->attrs.op = Op::Get("_contrib_requantize"); + requantize_node->attrs.name = "requantize_" + node->attrs.name; + if (requantize_node->op()->attr_parser != nullptr) { + requantize_node->op()->attr_parser(&(requantize_node->attrs)); + } + for (size_t i = 0; i < 3; ++i) { + requantize_node->inputs.emplace_back(NodeEntry{new_node, static_cast(i), 0}); + } + new_node = requantize_node; + } + } else { + // If the currently visited node does not need quantization, copy the current node to become + // the new_node. Meanwhile, check whether any inputs of the current node need quantization + // (e.g., a quantized_conv2d node), and insert a dequantize op node in the new graph if there + // are any. Otherwise, simply add a copy of the current node's entry to the inputs of + // the new_node. + *new_node = *node; + new_node->inputs.clear(); + for (const auto& e : node->inputs) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; + size_t num_outputs = e.node->num_outputs(); + uint32_t min_index = num_outputs + 2 * e.index; + uint32_t max_index = num_outputs + 2 * e.index + 1; + + // if input node is quantized operator, add dequantize node + if (NeedQuantize(e.node, excluded_nodes)) { + NodePtr dequantize_node = CreateNode("_contrib_dequantize", + e.node->attrs.name + "_dequantize"); + dequantize_node->inputs.emplace_back(mirror_entry); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); + + new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + mirror_map[e.node.get()] = std::move(dequantize_node); + } else { + new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version}); + } + } + } + mirror_map[node.get()] = std::move(new_node); + }); + + std::vector outputs; + for (const auto& e : src.outputs) { + if (quantized_op_map.count(e.node->op())) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; + size_t num_inputs = e.node->num_inputs(); + uint32_t min_index = num_inputs + 2 * e.index; + uint32_t max_index = num_inputs + 2 * e.index + 1; + + NodePtr dequantize_node = CreateNode("_contrib_dequantize", + e.node->attrs.name + "_dequantize"); + dequantize_node->inputs.emplace_back(mirror_entry); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); + outputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + } else { + outputs.emplace_back(NodeEntry{mirror_map.at(e.node.get()), e.index, e.version}); + } + } + + if (!offline_params.empty()) outputs = + OfflineParams(std::move(outputs), std::move(offline_params)); + + Graph ret; + ret.outputs = std::move(outputs); + return ret; +} + +Graph SetCalibTableToQuantizedGraph(Graph&& g) { + static const auto& flist_outputs = + nnvm::Op::GetAttr("FListOutputNames"); + static const auto& need_requantize_map = + nnvm::Op::GetAttr("FNeedRequantize"); + const auto& calib_table = + g.GetAttr>>("calib_table"); + DFSVisit(g.outputs, [&](const NodePtr& node) { + // If the current op is requantize + // find the thresholds from the calibration table with the key equal + // to the current op's input node name, e.g. a quantized_conv2d node. + if (node->op() != nullptr && node->op()->name == "_contrib_requantize") { + NodePtr quantized_op_node = node->inputs[0].node; + CHECK(quantized_op_node->op() != nullptr) << quantized_op_node->attrs.name + << " must be an quantized op node"; + CHECK(need_requantize_map.count(quantized_op_node->op()) > 0 + && need_requantize_map[quantized_op_node->op()](quantized_op_node->attrs)) + << quantized_op_node->attrs.name << " op must register FNeedRequantize attr" + " and the attr func should return true"; + std::string out_data_name = quantized_op_node->attrs.name + "_"; + auto list_output_names_func = flist_outputs.get(quantized_op_node->op(), nullptr); + // Here it's assumed that the quantized_op node only produces three outputs: + // out_data, min_range, and max_range. So we want to get the pre-calculated min_calib_range + // and max_calib_range from the calibration table for out_data. Here we create the output + // data name same as its constructed in GraphExecutor::ExecuteMonCallback. + if (list_output_names_func != nullptr) { + std::vector names = list_output_names_func(quantized_op_node->attrs); + CHECK_EQ(names.size(), 3U) << "ListOutputNames is expected to return three string for" + " quantized operators"; + out_data_name += names[0]; + } else { + out_data_name += "0"; + } + const auto calib_table_iter = calib_table.find(out_data_name); + if (calib_table_iter != calib_table.end()) { + node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); + node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); + node->op()->attr_parser(&(node->attrs)); + } + } + }); + return g; +} + +NNVM_REGISTER_PASS(QuantizeGraph) +.describe("") +.set_body(QuantizeGraph) +.set_change_graph(true); + +NNVM_REGISTER_PASS(SetCalibTableToQuantizedGraph) +.describe("") +.set_body(SetCalibTableToQuantizedGraph) +.set_change_graph(true); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc new file mode 100644 index 000000000000..d7dc9fe4dbd8 --- /dev/null +++ b/src/operator/quantization/quantized_conv.cc @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_conv.cc + * \brief + * \author Ziheng Jiang, Jun Wu +*/ +#include "../nn/convolution-inl.h" + +namespace mxnet { +namespace op { + +bool QuantizedConvShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + using namespace mshadow; + const ConvolutionParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.num_group, 1U) << "quantized_conv only supports num_group=1 for now"; + CHECK_EQ(in_shape->size(), param.no_bias? 6U : 9U); + CHECK_EQ(out_shape->size(), 3U); + if (param.layout.has_value()) { + CHECK_EQ(param.layout.value(), mshadow::kNCHW) << "quantized_conv only supports NCHW for now"; + } + CHECK_EQ(param.kernel.ndim(), 2U) << "quantized_conv only supports 2D convolution for now"; + CHECK(param.dilate.ndim() == 0U || param.dilate.Size() == 1U) + << "quantized_conv only supports dilation=1 for all dimensions"; + const TShape& dshape = in_shape->at(0); + CHECK_EQ(dshape.ndim(), 4U); + if (dshape.ndim() == 0U) return false; + + const int N = 0, H = 2, W = 3, C = 1; + CHECK_EQ(dshape[C] % 4, 0U) + << "for 8bit cudnn conv, the number of channel must be multiple of 4"; + CHECK_EQ(param.num_filter % 4, 0U) + << "for 8bit cudnn conv, the number of channel must be multiple of 4"; + + TShape wshape{0, 0, 0, 0}; + wshape[N] = param.num_filter; + wshape[H] = param.kernel[0]; + wshape[W] = param.kernel[1]; + wshape[C] = dshape[C]; + SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape); + const int start = param.no_bias? 2 : 3; + const int end = param.no_bias? 6 : 9; + for (int i = start; i < end; ++i) { + SHAPE_ASSIGN_CHECK(*in_shape, i, TShape{1}); + } + if (!param.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, 2, Shape1(param.num_filter)); + } + + auto AddPad = [](index_t dsize, index_t pad) { return dsize + 2 * pad; }; + TShape oshape{1, 1, 1, 1}; + oshape[N] = dshape[N]; + oshape[C] = wshape[N]; + oshape[H] = (AddPad(dshape[H], param.pad[0]) - wshape[H]) / param.stride[0] + 1; + oshape[W] = (AddPad(dshape[W], param.pad[1]) - wshape[W]) / param.stride[1] + 1; + + SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape); + SHAPE_ASSIGN_CHECK(*out_shape, 1, TShape({1})); + SHAPE_ASSIGN_CHECK(*out_shape, 2, TShape({1})); + return true; +} + +bool QuantizedConvType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const ConvolutionParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_type->size(), param.no_bias? 6U : 9U); + CHECK_EQ(out_type->size(), 3U); + TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); + TYPE_ASSIGN_CHECK(*in_type, 1, mshadow::kInt8); + if (!param.no_bias) { + TYPE_ASSIGN_CHECK(*in_type, 2, mshadow::kInt8); + } + + const size_t start = param.no_bias? 2 : 3; + const size_t end = param.no_bias? 6 : 9; + for (size_t i = start; i < end; ++i) { + TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32); + } + + TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt32); + TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32); + return true; +} + +NNVM_REGISTER_OP(_contrib_quantized_conv) +.describe(R"code(Convolution operator for input, weight and bias data type of int8, +and accumulates in type int32 for the output. For each argument, two more arguments of type +float32 must be provided representing the thresholds of quantizing argument from data +type float32 to int8. The final outputs contain the convolution result in int32, and min +and max thresholds representing the threholds for quantizing the float32 output into int32. + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) +.set_num_inputs( + [](const NodeAttrs& attrs) { + const ConvolutionParam& param = nnvm::get(attrs.parsed); + return param.no_bias? 6 : 9; + }) +.set_num_outputs(3) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ConvolutionParam& param = nnvm::get(attrs.parsed); + if (param.no_bias) { + return std::vector{"data", "weight", "min_data", "max_data", + "min_weight", "max_weight"}; + } else { + return std::vector{"data", "weight", "bias", "min_data", "max_data", + "min_weight", "max_weight", "min_bias", "max_bias"}; + } + }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "min_output", "max_output"}; + }) +.set_attr("FInferShape", QuantizedConvShape) +.set_attr("FInferType", QuantizedConvType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector(1, ResourceRequest::kTempSpace); + }) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.add_argument("data", "NDArray-or-Symbol", "Input data.") +.add_argument("weight", "NDArray-or-Symbol", "weight.") +.add_argument("bias", "NDArray-or-Symbol", "bias.") +.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.") +.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.") +.add_argument("min_weight", "NDArray-or-Symbol", "Minimum value of weight.") +.add_argument("max_weight", "NDArray-or-Symbol", "Maximum value of weight.") +.add_argument("min_bias", "NDArray-or-Symbol", "Minimum value of bias.") +.add_argument("max_bias", "NDArray-or-Symbol", "Maximum value of bias.") +.add_arguments(ConvolutionParam::__FIELDS__()); + +NNVM_REGISTER_OP(Convolution) +.set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_conv"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_conv.cu b/src/operator/quantization/quantized_conv.cu new file mode 100644 index 000000000000..2db5416309b5 --- /dev/null +++ b/src/operator/quantization/quantized_conv.cu @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_conv.cu + * \brief + * \author Ziheng Jiang, Jun Wu +*/ +#include "../nn/convolution-inl.h" +#include "./quantization_utils.h" +#include "../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +// value + bias_value * (range1 / limit_range1) * (limit_range2 / range2) +struct QuantizedBiasAddKernel { + MSHADOW_XINLINE static void Map(int i, size_t bias_size, int32_t *out, + const int8_t *bias, const float *min_out, + const float *max_out, const float *min_bias, + const float *max_bias, const size_t spatial_size) { + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + float float_for_one_out_quant = + MaxAbs(*min_out, *max_out) / static_cast(MaxValue()); + float float_for_one_bias_quant = + MaxAbs(*min_bias, *max_bias) / static_cast(MaxValue()); + const size_t channel_id = (i / spatial_size) % bias_size; + out[i] = (out[i] * float_for_one_out_quant + + bias[channel_id] * float_for_one_bias_quant) / + float_for_one_out_quant; + } +}; + +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 +template +class QuantizedCuDNNConvOp { + public: + QuantizedCuDNNConvOp() { + CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&data_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); + CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc_)); + } + + void Init(const ConvolutionParam& param, + const OpContext& ctx, + const std::vector& in_shape, + const std::vector& out_shape) { + param_ = param; + CHECK_EQ(param_.kernel.ndim(), 2U) + << "QuantizedCuDNNConvOp only supports 2D convolution for now"; + if (param_.layout.has_value()) { + CHECK_EQ(param_.layout.value(), mshadow::kNCHW) + << "QuantizedConvOp only supports NCHW for now"; + } + if (param_.stride.ndim() == 0U) param_.stride = mshadow::Shape2(1, 1); + if (param_.dilate.ndim() == 0U) param_.dilate = mshadow::Shape2(1, 1); + if (param_.pad.ndim() == 0U) param_.pad = mshadow::Shape2(0, 0); + N = 0, H = 2, W = 3, C = 1; + src_type_ = mshadow::DataType::kCudnnFlag; + dst_type_ = mshadow::DataType::kCudnnFlag; + cmp_type_ = mshadow::DataType::kCudnnFlag; + algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + format_ = CUDNN_TENSOR_NHWC; + InitDescriptors(in_shape, out_shape); + GetTempSize(ctx); + } + + ~QuantizedCuDNNConvOp() { + CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(data_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); + CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); + } + + void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + CHECK_EQ(param_.kernel.ndim(), 2U) + << "QuantizedCuDNNConvOp only supports 2D convolution for now"; + using namespace mshadow; + CHECK_EQ(in_data.size(), param_.no_bias? 6U : 9U); + CHECK_EQ(out_data.size(), 3U); + Stream *s = ctx.get_stream(); + CHECK_EQ(s->dnn_handle_ownership_, Stream::OwnHandle); + + const TBlob& data = in_data[0]; + const TBlob& filter = in_data[1]; + const TBlob& out = out_data[0]; + const TShape& dshape = data.shape_; + const TShape& fshape = filter.shape_; + const TShape& oshape = out.shape_; + + // allocate workspace + const int dev_id = ctx.run_ctx.ctx.dev_id; + const int dev_mask = gpu::kDevMask; + if (!param_.layout.has_value() || param_.layout.value() == mshadow::kNCHW) { + const size_t data_size = dshape.Size(); + const size_t weight_size = fshape.Size(); + const size_t output_size = oshape.Size(); + size_t total_temp_bytes = (workspace_ + data_size + weight_size) * sizeof(SrcType) + + output_size * (sizeof(DstType) + sizeof(int32_t)); + Tensor temp_space = + ctx.requested[0].get_space_typed(mshadow::Shape1(total_temp_bytes), s); + char* temp_dptr = temp_space.dptr_; + TBlob data_(reinterpret_cast(temp_dptr), + TShape({dshape[N], dshape[H], dshape[W], dshape[C]}), + dev_mask, DataType::kFlag, dev_id); + temp_dptr += data_size * sizeof(SrcType); + TBlob filter_(reinterpret_cast(temp_dptr), + TShape({fshape[N], fshape[H], fshape[W], fshape[C]}), + dev_mask, DataType::kFlag, dev_id); + temp_dptr += weight_size * sizeof(SrcType); + + // input: [NCHW] => [NHWC](batch, in_height, in_width, in_channels) + // filter: [NCHW] => [NHWC](out_channels, filter_height, filter_width, in_channels) + TransposeImpl(ctx.run_ctx, data, data_, TShape({N, H, W, C})); + TransposeImpl(ctx.run_ctx, filter, filter_, TShape({N, H, W, C})); + TBlob out_(reinterpret_cast(temp_dptr), + TShape({oshape[N], oshape[H], oshape[W], oshape[C]}), + dev_mask, DataType::kFlag, dev_id); + temp_dptr += output_size * sizeof(DstType); + TBlob out_tcast(reinterpret_cast(temp_dptr), + TShape({oshape[N], oshape[H], oshape[W], oshape[C]}), + dev_mask, DataType::kFlag, dev_id); + temp_dptr += output_size * sizeof(int32_t); + // input: [NHWC](batch, in_height, in_width, in_channels) + // filter: [HWNC](out_channels, filter_height, filter_width, in_channels) + // output: [NHWC](batch, out_height, out_width, out_channels) + + CUDNN_CALL(cudnnConvolutionForward(s->dnn_handle_, + &alpha_, + data_desc_, + data_.dptr_, + filter_desc_, + filter_.dptr_, + conv_desc_, + algo_, + temp_dptr, + workspace_byte_, + &beta_, + out_desc_, + out_.dptr_)); + + Tensor out_tensor = out_.FlatTo1D(s); + Tensor out_tcast_tensor = out_tcast.FlatTo1D(s); + Assign(out_tcast_tensor, kWriteTo, mshadow::expr::tcast(out_tensor)); + // output: [NHWC](batch, out_height, out_width, out_channels) => [NCHW] + TransposeImpl(ctx.run_ctx, out_tcast, out, TShape({0, 3, 1, 2})); + } else { + LOG(FATAL) << "quantized_conv only supports NCHW for now"; + } + + // calculate the min/max range for out_data as it's a multiplication + // of in_data[0] and in_data[1]. Need to rescale the min/max range of out_data + // based on the min/max ranges of in_data[0] and in_data[1]. + const size_t num_inputs = param_.no_bias ? 2 : 3; + mxnet_op::Kernel::Launch(s, 1, + out_data[1].dptr(), out_data[2].dptr(), + in_data[num_inputs].dptr(), in_data[num_inputs+1].dptr(), + in_data[num_inputs+2].dptr(), in_data[num_inputs+3].dptr()); + + if (!param_.no_bias) { + if (param_.layout.has_value()) { + CHECK_EQ(param_.layout.value(), mshadow::kNCHW) + << "quantized_conv only supports NCHW when there is a bias"; + } + const TBlob& bias = in_data[2]; + mxnet_op::Kernel::Launch(s, out.Size(), + bias.Size(), out.dptr(), bias.dptr(), + out_data[1].dptr(), out_data[2].dptr(), + in_data[7].dptr(), in_data[8].dptr(), + oshape[2] * oshape[3]); + } + } + + void InitDescriptors(const std::vector& in_shape, + const std::vector& out_shape) { + const TShape& dshape = in_shape[0]; + const TShape& kshape = in_shape[1]; + const TShape& oshape = out_shape[0]; + CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv_desc_, + param_.pad[0], + param_.pad[1], + param_.stride[0], + param_.stride[1], + 1, + 1, + CUDNN_CROSS_CORRELATION, + cmp_type_)); + + CUDNN_CALL(cudnnSetTensor4dDescriptor(data_desc_, + format_, + src_type_, + dshape[N], + dshape[C], + dshape[H], + dshape[W])); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, + format_, + dst_type_, + oshape[N], + oshape[C], + oshape[H], + oshape[W])); + CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, + src_type_, + format_, + kshape[N], + kshape[C], + kshape[H], + kshape[W])); + } + + void GetTempSize(const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, + data_desc_, + filter_desc_, + conv_desc_, + out_desc_, + algo_, + &workspace_byte_)); + workspace_ = workspace_byte_ / sizeof(SrcType) + 1; + } + + private: + ConvolutionParam param_; + size_t workspace_; + size_t workspace_byte_; + cudnnDataType_t src_type_; + cudnnDataType_t dst_type_; + cudnnDataType_t cmp_type_; + cudnnTensorFormat_t format_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t data_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnTensorDescriptor_t out_desc_; + cudnnConvolutionFwdAlgo_t algo_; + uint32_t N, H, W, C; + float alpha_ = 1.0f; + float beta_ = 0.0f; +}; // class QuantizedCuDNNConvOp +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 + +void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const ConvolutionParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.kernel.ndim(), 2U) + << "QuantizedConvForward only supports 2D convolution for now"; +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 + typedef QuantizedCuDNNConvOp QuantizedConvOpInt8; +#if DMLC_CXX11_THREAD_LOCAL + static thread_local QuantizedConvOpInt8 op; +#else + static MX_THREAD_LOCAL QuantizedConvOpInt8 op; +#endif // DMLC_CXX11_THREAD_LOCAL + op.Init(param, ctx, {inputs[0].shape_, inputs[1].shape_}, {outputs[0].shape_}); + op.Forward(ctx, inputs, req, outputs); +#else + LOG(FATAL) << "QuantizedConvForward only supports cudnnConvolutionForward for now"; +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 +} + +NNVM_REGISTER_OP(_contrib_quantized_conv) +.set_attr("FCompute", QuantizedConvForwardGPU); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_flatten-inl.h b/src/operator/quantization/quantized_flatten-inl.h new file mode 100644 index 000000000000..95f366154022 --- /dev/null +++ b/src/operator/quantization/quantized_flatten-inl.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_flatten-inl.h + * \brief implementation of quantized flatten operation + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_FLATTEN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_FLATTEN_INL_H_ + +#include +#include +#include +#include "../elemwise_op_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "./quantization_utils.h" + +namespace mxnet { +namespace op { + +// keep zero-center +struct quantized_flatten { + template + MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, + float *omax_range, const SrcDType *in, + const float *imin_range, const float *imax_range) { + out[i] = in[i]; + omin_range[0] = imin_range[0]; + omax_range[0] = imax_range[0]; + } +}; + +template +void QuantizedFlattenCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 3U); + CHECK_EQ(req.size(), 3U); + if (req[0] == kWriteInplace && req[1] == kWriteInplace && req[2] == kWriteInplace) return; + using namespace mshadow; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + + typedef int8_t DstDType; + typedef int8_t SrcDType; + Kernel::Launch(s, outputs[0].Size(), + outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr()); +} + +inline bool QuantizedFlattenShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 3U); + + const TShape &dshape = (*in_attrs)[0]; + if (shape_is_none(dshape)) return false; + + uint32_t target_dim = 1; + for (uint32_t i = 1; i < dshape.ndim(); ++i) { + target_dim *= dshape[i]; + } + + SHAPE_ASSIGN_CHECK(*in_attrs, 1, TShape{1}); + SHAPE_ASSIGN_CHECK(*in_attrs, 2, TShape{1}); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(dshape[0], target_dim)); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1}); + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1}); + return true; +} + +inline bool QuantizedFlattenType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 3U); + TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kInt8); + TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); + return (*in_attrs)[0] != -1; +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_FLATTEN_INL_H_ diff --git a/src/operator/quantization/quantized_flatten.cc b/src/operator/quantization/quantized_flatten.cc new file mode 100644 index 000000000000..3f426a59bdd2 --- /dev/null +++ b/src/operator/quantization/quantized_flatten.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_flatten.cc + * \brief + */ +#include +#include "./quantized_flatten-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_quantized_flatten) +.set_num_inputs(3) +.set_num_outputs(3) +.set_attr("FInferShape", QuantizedFlattenShape) +.set_attr("FInferType", QuantizedFlattenType) +.set_attr("FCompute", QuantizedFlattenCompute) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "min_data", "max_data"}; + }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "min_output", "max_output"}; + }) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 1}, {2, 2}}; + }) +.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") +.add_argument("min_data", "NDArray-or-Symbol", "The minimum scalar value " + "possibly produced for the data") +.add_argument("max_data", "NDArray-or-Symbol", "The maximum scalar value " + "possibly produced for the data"); + +NNVM_REGISTER_OP(Flatten) +.set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_flatten"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_flatten.cu b/src/operator/quantization/quantized_flatten.cu new file mode 100644 index 000000000000..4f0c8f93ab06 --- /dev/null +++ b/src/operator/quantization/quantized_flatten.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_flatten.cu + * \brief + */ +#include "./quantized_flatten-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_quantized_flatten) +.set_attr("FCompute", QuantizedFlattenCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc new file mode 100644 index 000000000000..e334fe7ec9b2 --- /dev/null +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_fully_connected.cc + * \brief + * \author Ziheng Jiang, Jun Wu +*/ +#include "../nn/fully_connected-inl.h" + +namespace mxnet { +namespace op { + +bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + CHECK(param.flatten) << "QuantizedFullyConnectedOp only supports flatten=true for now"; + using namespace mshadow; + uint32_t num_inputs = param.no_bias ? 2 : 3; + CHECK_EQ(in_shape->size(), num_inputs * 3); + CHECK_EQ(out_shape->size(), 3U); + + CHECK(!shape_is_none(in_shape->at(0))) + << "QuantizedFullyConnectedOp input data shape must be given"; + const TShape& dshape = in_shape->at(0); + TShape wshape = Shape2(param.num_hidden, dshape.ProdShape(1, dshape.ndim())); + SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape); + if (!param.no_bias) { + TShape bshape = Shape1(param.num_hidden); + SHAPE_ASSIGN_CHECK(*in_shape, 2, bshape); + } + + for (size_t i = num_inputs; i < 3 * num_inputs; ++i) { + SHAPE_ASSIGN_CHECK(*in_shape, i, TShape{1}); + } + + SHAPE_ASSIGN_CHECK(*out_shape, 0, TShape({dshape[0], wshape[0]})); + SHAPE_ASSIGN_CHECK(*out_shape, 1, TShape({1})); + SHAPE_ASSIGN_CHECK(*out_shape, 2, TShape({1})); + return true; +} + +bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + uint32_t num_inputs = param.no_bias ? 2 : 3; + CHECK_EQ(in_type->size(), num_inputs * 3); + CHECK_EQ(out_type->size(), 3U); + + for (size_t i = 0; i < num_inputs; ++i) { + TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8); + } + for (size_t i = num_inputs; i < 3 * num_inputs; ++i) { + TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32); + } + + TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt32); + TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32); + return true; +} + +NNVM_REGISTER_OP(_contrib_quantized_fully_connected) +.describe(R"code(Fully Connected operator for input, weight and bias data type of int8, +and accumulates in type int32 for the output. For each argument, two more arguments of type +float32 must be provided representing the thresholds of quantizing argument from data +type float32 to int8. The final outputs contain the convolution result in int32, and min +and max thresholds representing the threholds for quantizing the float32 output into int32. + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) +.set_num_inputs( + [](const NodeAttrs& attrs) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + return param.no_bias? 6 : 9; + }) +.set_num_outputs(3) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + if (param.no_bias) { + return std::vector{"data", "weight", "min_data", "max_data", + "min_weight", "max_weight"}; + } else { + return std::vector{"data", "weight", "bias", "min_data", "max_data", + "min_weight", "max_weight", "min_bias", "max_bias"}; + } + }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "min_output", "max_output"}; + }) +.set_attr("FInferShape", QuantizedFullyConnectedShape) +.set_attr("FInferType", QuantizedFullyConnectedType) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.add_argument("data", "NDArray-or-Symbol", "Input data.") +.add_argument("weight", "NDArray-or-Symbol", "weight.") +.add_argument("bias", "NDArray-or-Symbol", "bias.") +.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.") +.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.") +.add_argument("min_weight", "NDArray-or-Symbol", "Minimum value of weight.") +.add_argument("max_weight", "NDArray-or-Symbol", "Maximum value of weight.") +.add_argument("min_bias", "NDArray-or-Symbol", "Minimum value of bias.") +.add_argument("max_bias", "NDArray-or-Symbol", "Maximum value of bias.") +.add_arguments(FullyConnectedParam::__FIELDS__()); + +NNVM_REGISTER_OP(FullyConnected) +.set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_fully_connected"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_fully_connected.cu b/src/operator/quantization/quantized_fully_connected.cu new file mode 100644 index 000000000000..ac7ba1e21df8 --- /dev/null +++ b/src/operator/quantization/quantized_fully_connected.cu @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_fully_connected.cu + * \brief + * \author Ziheng Jiang, Jun Wu +*/ +#include "./quantization_utils.h" +#include "../mxnet_op.h" +#include "../nn/fully_connected-inl.h" + +namespace mxnet { +namespace op { + +// value + bias_value * (range1 / limit_range1) * (limit_range2 / range2) +struct QuantizedBiasAddKernel { + MSHADOW_XINLINE static void Map(int i, size_t k, int32_t *out, + const int8_t *bias, const float *min_out, + const float *max_out, const float *min_bias, + const float *max_bias) { + typedef int32_t T1; + typedef int8_t T2; + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + float float_for_one_out_quant = + MaxAbs(*min_out, *max_out) / static_cast(MaxValue()); + float float_for_one_bias_quant = + MaxAbs(*min_bias, *max_bias) / static_cast(MaxValue()); + out[i] = (out[i] * float_for_one_out_quant + + bias[i%k] * float_for_one_bias_quant) / + float_for_one_out_quant; + } +}; + +template +void QuantizedFullyConnectedForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + using namespace mshadow; + using namespace mxnet_op; + size_t num_inputs = param.no_bias ? 2 : 3; + CHECK_EQ(inputs.size(), num_inputs * 3); + CHECK_EQ(outputs.size(), 3U); + Stream *s = ctx.get_stream(); + CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle); + const TBlob& data = inputs[0]; + const TBlob& weight = inputs[1]; + const TBlob& out = outputs[0]; + TShape dshape = data.shape_; + TShape wshape = weight.shape_; + TShape oshape = out.shape_; + // (m, n) * (k, n).T = (m, k) + // A * B.T = C + + // row_C = col_C(T) = cublas(col_B * col_A(T)) = cublas(row_B(T), row_A) + // row_C = col_C(T) = cublas(col_B(T) * col_A(T)) = cublas(row_B, row_A) + const int m = dshape[0], n = dshape.ProdShape(1, dshape.ndim()), k = wshape[0]; + CmpType alpha = 1.0f; + CmpType beta = 0.0f; + const cudaDataType src_type = mshadow::DataType::kCudaFlag; + const cudaDataType dst_type = mshadow::DataType::kCudaFlag; + const cudaDataType cmp_type = mshadow::DataType::kCudaFlag; + CUBLAS_CALL(cublasGemmEx(s->blas_handle_, + CUBLAS_OP_T, + CUBLAS_OP_N, + k, + m, + n, + &alpha, + weight.dptr_, + src_type, + n, + data.dptr_, + src_type, + n, + &beta, + out.dptr_, + dst_type, + k, + cmp_type, + CUBLAS_GEMM_DFALT)); + + Kernel::Launch(s, 1, + outputs[1].dptr(), outputs[2].dptr(), + inputs[num_inputs].dptr(), inputs[num_inputs+1].dptr(), + inputs[num_inputs+2].dptr(), inputs[num_inputs+3].dptr()); + + if (!param.no_bias) { + const TBlob& bias = inputs[2]; + Kernel::Launch(s, out.Size(), + k, out.dptr(), bias.dptr(), + outputs[1].dptr(), outputs[2].dptr(), + inputs[7].dptr(), inputs[8].dptr()); + } +} + +NNVM_REGISTER_OP(_contrib_quantized_fully_connected) +.set_attr("FCompute", QuantizedFullyConnectedForwardGPU); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc new file mode 100644 index 000000000000..71f4e738161d --- /dev/null +++ b/src/operator/quantization/quantized_pooling.cc @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_pooling.cc +*/ +#include +#include "../nn/pooling-inl.h" + +namespace mxnet { +namespace op { + +bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 3U); + if (shape_is_none(in_shape->at(0))) return false; + const TShape &dshape = (*in_shape)[0]; + CHECK_EQ(dshape.ndim(), 4U) + << "quantized_pooling: Input data should be 4D in " + << "(batch, channel, y, x)"; + // NCHW layout + const int N = 0, H = 2, W = 3, C = 1; + TShape oshape(4); + CHECK_EQ(param.kernel.ndim(), 2) << "QuantizedPoolingOp only supports 2D pooling for now"; + CHECK(param.kernel[0] <= dshape[H] + 2 * param.pad[0]) + << "kernel size (" << param.kernel[0] + << ") exceeds input (" << dshape[H] + << " padded to " << (dshape[H] + 2*param.pad[0]) << ")"; + CHECK(param.kernel[1] <= dshape[W] + 2 * param.pad[1]) + << "kernel size (" << param.kernel[1] + << ") exceeds input (" << dshape[W] + << " padded to " << (dshape[W] + 2*param.pad[1]) << ")"; + // only support valid convention + oshape[N] = dshape[N]; + oshape[C] = dshape[C]; + if (param.global_pool) { + oshape[H] = 1; + oshape[W] = 1; + } else { + oshape[H] = 1 + (dshape[H] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; + oshape[W] = 1 + (dshape[W] + 2 * param.pad[1] - param.kernel[1]) / + param.stride[1]; + } + + SHAPE_ASSIGN_CHECK(*in_shape, 1, TShape{1}); + SHAPE_ASSIGN_CHECK(*in_shape, 2, TShape{1}); + + out_shape->clear(); + out_shape->push_back(oshape); + out_shape->push_back(TShape{1}); + out_shape->push_back(TShape{1}); + return true; +} + +bool QuantizedPoolingType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_type->size(), 3U); + CHECK_EQ(out_type->size(), 3U); + if (param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) { + TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); + TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt8); + } else { + LOG(FATAL) << "QuantizedPoolingOp only supports pool_type=max/avg for now"; + } + TYPE_ASSIGN_CHECK(*in_type, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_type, 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32); + return true; +} + +NNVM_REGISTER_OP(_contrib_quantized_pooling) +.set_num_inputs(3) +.set_num_outputs(3) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "min_data", "max_data"}; + }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "min_output", "max_output"}; + }) +.set_attr("FInferShape", QuantizedPoolingShape) +.set_attr("FInferType", QuantizedPoolingType) +.set_attr("FNeedRequantize", + [](const NodeAttrs& attrs) { + const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK(param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) + << "QuantizedPoolingOp only supports pool_type=max/avg for now"; + return false; + }) +.add_argument("data", "NDArray-or-Symbol", "Input data.") +.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.") +.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.") +.add_arguments(PoolingParam::__FIELDS__()); + +NNVM_REGISTER_OP(Pooling) +.describe(R"code(Pooling operator for input and output data type of int8. +The input and output data comes with min and max thresholds for quantizing +the float32 data into int8. + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training. + This operator only supports `pool_type` of `avg` or `max`.)code" ADD_FILELINE) +.set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { + PoolingParam param; + param.Init(attrs.dict); + // TODO(junwu): Uncomment the following line and remove the above lines + // after pooling op is refactored + // const PoolingParam& param = nnvm::get(attrs.parsed); + nnvm::NodePtr node = nnvm::Node::Create(); + if (param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) { + node->attrs.op = Op::Get("_contrib_quantized_pooling"); + node->attrs.name = "quantized_" + attrs.name; + } else { + node->attrs.op = Op::Get("Pooling"); + node->attrs.name = attrs.name; + } + node->attrs.dict = attrs.dict; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantized_pooling.cu b/src/operator/quantization/quantized_pooling.cu new file mode 100644 index 000000000000..78011b885c53 --- /dev/null +++ b/src/operator/quantization/quantized_pooling.cu @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantized_pooling.cu +*/ +#include +#include +#include "../nn/pooling-inl.h" +#include "../mshadow_op.h" + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 +template +class QuantizedCuDNNPoolingOp { + public: + QuantizedCuDNNPoolingOp() { + CUDNN_CALL(cudnnCreatePoolingDescriptor(&pool_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); + } + + void Init(const PoolingParam& param, const TShape& dshape, const TShape& oshape) { + const int N = 0, H = 2, W = 3, C = 1; + const cudnnDataType_t dtype = mshadow::DataType::kCudnnFlag; + CHECK(param.kernel.ndim() == 2) << "Only support 2D pooling"; + if (param.pool_type == pool_enum::kMaxPooling) { + mode_ = CUDNN_POOLING_MAX; + } else if (param.pool_type == pool_enum::kAvgPooling) { + mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + } else { + LOG(FATAL) << "QuantizedCuDNNPoolingOp only supports pool_type=max/avg"; + } + CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_, + CUDNN_TENSOR_NCHW, + dtype, + dshape[N], + dshape[C], + dshape[H], + dshape[W])); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, + CUDNN_TENSOR_NCHW, + dtype, + oshape[N], + oshape[C], + oshape[H], + oshape[W])); + CUDNN_CALL(cudnnSetPooling2dDescriptor(pool_desc_, + mode_, + CUDNN_NOT_PROPAGATE_NAN, + param.global_pool ? dshape[2] : param.kernel[0], + param.global_pool ? dshape[3] : param.kernel[1], + param.pad[0], + param.pad[1], + param.global_pool ? 1 : param.stride[0], + param.global_pool ? 1 :param.stride[1])); + } + + ~QuantizedCuDNNPoolingOp() { + CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); + CUDNN_CALL(cudnnDestroyPoolingDescriptor(pool_desc_)); + } + + void Forward(mshadow::Stream* s, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 3U); + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + float alpha = 1.0f; + float beta = 0.0f; + CUDNN_CALL(cudnnPoolingForward(s->dnn_handle_, + pool_desc_, + &alpha, + in_desc_, + inputs[0].dptr_, + &beta, + out_desc_, + outputs[0].dptr_)); + + Tensor omin_range = outputs[1].FlatTo1D(s); + Tensor omax_range = outputs[2].FlatTo1D(s); + ASSIGN_DISPATCH(omin_range, req[1], + F(inputs[1].FlatTo1D(s))); + ASSIGN_DISPATCH(omax_range, req[2], + F(inputs[2].FlatTo1D(s))); + } + + private: + cudnnPoolingMode_t mode_; + cudnnTensorDescriptor_t in_desc_; + cudnnTensorDescriptor_t out_desc_; + cudnnPoolingDescriptor_t pool_desc_; +}; // class QuantizedCuDNNPoolingOp +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 + +void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const PoolingParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.kernel.ndim(), 2U) + << "QuantizedPoolingForward only supports 2D convolution for now"; +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 +#if DMLC_CXX11_THREAD_LOCAL + static thread_local QuantizedCuDNNPoolingOp op; +#else + static MX_THREAD_LOCAL QuantizedCuDNNPoolingOp op; +#endif // DMLC_CXX11_THREAD_LOCAL + op.Init(param, {inputs[0].shape_}, {outputs[0].shape_}); + op.Forward(ctx.get_stream(), inputs, req, outputs); +#else + LOG(FATAL) << "QuantizedPoolingForward only supports cudnnPoolingForward for now"; +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 +} + +NNVM_REGISTER_OP(_contrib_quantized_pooling) +.set_attr("FCompute", QuantizedPoolingForwardGPU); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h new file mode 100644 index 000000000000..e07a149f8a6b --- /dev/null +++ b/src/operator/quantization/requantize-inl.h @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file requantize-inl.h + * \brief implementation of quantize operation + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_REQUANTIZE_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_REQUANTIZE_INL_H_ + +#include +#include +#include +#include "../elemwise_op_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "./quantization_utils.h" +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct RequantizeParam : public dmlc::Parameter { + dmlc::optional min_calib_range; // min float value calculated from calibration dataset + dmlc::optional max_calib_range; // max float value calculated from calibration dataset + DMLC_DECLARE_PARAMETER(RequantizeParam) { + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to requantize the " + "int32 data into int8."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to requantize the " + "int32 data into int8."); + } +}; + +inline bool RequantizeType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 3U); + TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kInt32); + TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); + return (*in_attrs)[0] != -1; +} + +struct RequantizeKernel { + template + MSHADOW_XINLINE static void Map(int i, T2 *output, float *omin_range, float *omax_range, + const T1 *input, const float *imin_range, const float *imax_range, const float real_range) { + const float input_float = QuantizedToFloat(input[i], *imin_range, *imax_range); + *omin_range = -real_range; + *omax_range = real_range; + output[i] = FloatToQuantized(input_float, -real_range, real_range); + } + + template + MSHADOW_XINLINE static void Map(int i, T2 *output, float *omin_range, float *omax_range, + const T1 *input, const float *imin_range, const float *imax_range, + const float *actual_min, const float *actual_max) { + Map(i, output, omin_range, omax_range, input, imin_range, imax_range, + MaxAbs(*actual_min, *actual_max)); + } +}; + +template +inline size_t ConfigReduce(mshadow::Stream* s, + const TShape& data_shape, + const TShape& out_shape, + TShape* src_shape, + TShape* dst_shape) { + BroadcastReduceShapeCompact(data_shape, out_shape, src_shape, dst_shape); + constexpr int NDim = 2; + CHECK_EQ(src_shape->ndim(), NDim); + CHECK_EQ(dst_shape->ndim(), NDim); + + return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape); +} + +template +void RequantizeForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + typedef int32_t SrcDType; + typedef int8_t DstDType; + Stream *s = ctx.get_stream(); + const RequantizeParam& param = + nnvm::get(attrs.parsed); + + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + Kernel::Launch(s, inputs[0].Size(), + outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), + MaxAbs(param.min_calib_range.value(), param.max_calib_range.value())); + } else { // model is not calibrated + TShape src_shape, dst_shape; + const size_t actual_float_size = sizeof(float); + const size_t actual_quantized_size = sizeof(SrcDType); + const size_t temp_reduce_size = ConfigReduce( + s, inputs[0].shape_, TShape({1}), &src_shape, &dst_shape); + Tensor temp_space = + ctx.requested[0].get_space_typed( + Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s); + Tensor actual_min_float( + reinterpret_cast(temp_space.dptr_), Shape1(1), s); + Tensor actual_max_float( + reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), s); + + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob actual_min_quantized(reinterpret_cast( + temp_space.dptr_ + 8), Shape1(1), xpu::kDevMask, dev_id); + TBlob actual_max_quantized(reinterpret_cast( + temp_space.dptr_ + 8) + 1, Shape1(1), xpu::kDevMask, dev_id); + Tensor workspace( + temp_space.dptr_+2*actual_float_size+2*actual_quantized_size, Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, actual_min_quantized.reshape(dst_shape), + kWriteTo, workspace, inputs[0].reshape(src_shape)); + Kernel::Launch(s, 1, + actual_min_float.dptr_, actual_min_quantized.dptr(), + inputs[1].dptr(), inputs[2].dptr()); + + broadcast::Reduce( + s, actual_max_quantized.reshape(dst_shape), + kWriteTo, workspace, inputs[0].reshape(src_shape)); + Kernel::Launch(s, 1, + actual_max_float.dptr_, actual_max_quantized.dptr(), + inputs[1].dptr(), inputs[2].dptr()); + + Kernel::Launch(s, inputs[0].Size(), + outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), + inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), + actual_min_float.dptr_, actual_max_float.dptr_); + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_QUANTIZATION_REQUANTIZE_INL_H_ diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc new file mode 100644 index 000000000000..83ea37b835cf --- /dev/null +++ b/src/operator/quantization/requantize.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file requantize.cc + * \brief + */ +#include "./requantize-inl.h" +#include "./quantize-inl.h" + +namespace mxnet { +namespace op { +DMLC_REGISTER_PARAMETER(RequantizeParam); + +NNVM_REGISTER_OP(_contrib_requantize) +.describe(R"code(Given data that is quantized in int32 and the corresponding thresholds, +requantize the data into int8 using min and max thresholds either calculated at runtime +or from calibration. It's highly recommended to pre-calucate the min and max thresholds +through calibration since it is able to save the runtime of the operator and improve the +inference accuracy. + +.. Note:: + This operator only supports forward propogation. DO NOT use it in training.)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(3) +.set_num_outputs(3) +.set_attr("FInferShape", QuantizeShape) +.set_attr("FInferType", RequantizeType) +.set_attr("FCompute", RequantizeForward) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) { + const RequantizeParam& param = + nnvm::get(attrs.parsed); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + return std::vector(); + } else { + return std::vector(1, ResourceRequest::kTempSpace); + } + }) +.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `int32`") +.add_argument("min_range", "NDArray-or-Symbol", "The original minimum scalar value " + "in the form of float32 used for quantizing data into int32.") +.add_argument("max_range", "NDArray-or-Symbol", "The original maximum scalar value " + "in the form of float32 used for quantizing data into int32.") +.add_arguments(RequantizeParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/requantize.cu b/src/operator/quantization/requantize.cu new file mode 100644 index 000000000000..be8ae59124e5 --- /dev/null +++ b/src/operator/quantization/requantize.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file quantize.cu + * \brief + */ +#include "./requantize-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_requantize) +.set_attr("FCompute", RequantizeForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index 630fef65a52e..b6bb39a19847 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -348,21 +348,21 @@ static inline uint64_t calc_num_load(const int X, const int Y, const int* stride } template -ReduceImplConfig ConfigureReduceImpl(const TBlob& small, const TBlob& big, const TBlob* lhs, - const TBlob* rhs) { +ReduceImplConfig ConfigureReduceImpl(const TShape& small, const TShape& big, const TShape* lhs, + const TShape* rhs) { ReduceImplConfig config; - diff(small.shape_.get(), big.shape_.get(), &config.rshape, &config.rstride); - config.N = small.shape_.Size(); + diff(small.get(), big.get(), &config.rshape, &config.rstride); + config.N = small.Size(); config.M = config.rshape.Size(); bool multiOp = false; if (lhs != NULL) { CHECK_NOTNULL(rhs); - diff(small.shape_.get(), lhs->shape_.get(), &config.lhs_shape, + diff(small.get(), lhs->get(), &config.lhs_shape, &config.lhs_stride); - diff(small.shape_.get(), rhs->shape_.get(), &config.rhs_shape, + diff(small.get(), rhs->get(), &config.rhs_shape, &config.rhs_stride); multiOp = true; } @@ -376,20 +376,20 @@ ReduceImplConfig ConfigureReduceImpl(const TBlob& small, const TBlob& big, } else { int reduce_strides[3]; - reduce_strides[0] = fastest_stride(small.shape_.get(), big.shape_.get(), - big.shape_.get()); - reduce_strides[1] = (multiOp) ? fastest_stride(small.shape_.get(), - lhs->shape_.get(), lhs->shape_.get()) : 1; - reduce_strides[2] = (multiOp) ? fastest_stride(small.shape_.get(), - rhs->shape_.get(), rhs->shape_.get()) : 1; + reduce_strides[0] = fastest_stride(small.get(), big.get(), + big.get()); + reduce_strides[1] = (multiOp) ? fastest_stride(small.get(), + lhs->get(), lhs->get()) : 1; + reduce_strides[2] = (multiOp) ? fastest_stride(small.get(), + rhs->get(), rhs->get()) : 1; int reduce_strides_transp[3]; - reduce_strides_transp[0] = fastest_stride(small.shape_.get(), config.rshape, + reduce_strides_transp[0] = fastest_stride(small.get(), config.rshape, config.rstride); reduce_strides_transp[1] = (multiOp) ? - fastest_stride(small.shape_.get(), config.lhs_shape, config.lhs_stride) : 1; + fastest_stride(small.get(), config.lhs_shape, config.lhs_stride) : 1; reduce_strides_transp[2] = (multiOp) ? - fastest_stride(small.shape_.get(), config.rhs_shape, config.rhs_stride) : 1; + fastest_stride(small.get(), config.rhs_shape, config.rhs_stride) : 1; uint64_t num_load = calc_num_load(config.N, config.M, reduce_strides); uint64_t num_load_transp = calc_num_load(config.M, config.N, reduce_strides_transp); @@ -597,7 +597,8 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config = ConfigureReduceImpl(small, big, NULL, NULL); + ReduceImplConfig config = + ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); ReduceImpl(stream, small, req, big, workspace, config); } @@ -607,21 +608,22 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, const TBlob& lhs, const TBlob& rhs) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config = ConfigureReduceImpl(small, big, &lhs, &rhs); + ReduceImplConfig config = + ConfigureReduceImpl(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_); ReduceImpl(stream, small, lhs, rhs, req, big, workspace, config); } template -size_t ReduceWorkspaceSize(Stream *s, const TBlob& small, const OpReqType req, - const TBlob& big) { +size_t ReduceWorkspaceSize(Stream *s, const TShape& small, const OpReqType req, + const TShape& big) { if (req == kNullOp) return 0; ReduceImplConfig config = ConfigureReduceImpl(small, big, NULL, NULL); return config.workspace_size; } template -size_t ReduceWorkspaceSize(Stream *s, const TBlob& small, const OpReqType req, - const TBlob& big, const TBlob& lhs, const TBlob& rhs) { +size_t ReduceWorkspaceSize(Stream *s, const TShape& small, const OpReqType req, + const TShape& big, const TShape& lhs, const TShape& rhs) { if (req == kNullOp) return 0; ReduceImplConfig config = ConfigureReduceImpl(small, big, &lhs, &rhs); return config.workspace_size; diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 7f3e5685a086..76ec92a9e724 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -217,14 +217,14 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, } template -size_t ReduceWorkspaceSize(Stream *s, const TBlob& small, const OpReqType req, - const TBlob& big) { +size_t ReduceWorkspaceSize(Stream *s, const TShape& small, const OpReqType req, + const TShape& big) { return 0; } template -size_t ReduceWorkspaceSize(Stream *s, const TBlob& small, const OpReqType req, - const TBlob& big, const TBlob& lhs, const TBlob& rhs) { +size_t ReduceWorkspaceSize(Stream *s, const TShape& small, const OpReqType req, + const TShape& big, const TShape& lhs, const TShape& rhs) { return 0; } diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 02d48b469703..f124ba3021e5 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -421,7 +421,7 @@ void ReduceAxesComputeImpl(const nnvm::NodeAttrs& attrs, const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data, req[0], in_data); + s, out_data.shape_, req[0], in_data.shape_); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); broadcast::Reduce( diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index af5f5ce3af80..a2e63fefad58 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -205,8 +205,10 @@ void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs, const TBlob out = inputs[0].reshape(new_oshape); BROADCAST_NDIM_SWITCH(ndim, NDim, { // Request temporary storage - size_t workspace_size_l = ReduceWorkspaceSize(s, lhs, req[0], out); - size_t workspace_size_r = ReduceWorkspaceSize(s, rhs, req[1], out); + size_t workspace_size_l = ReduceWorkspaceSize( + s, lhs.shape_, req[0], out.shape_); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rhs.shape_, req[1], out.shape_); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -234,8 +236,10 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx, const TBlob ograd = inputs[0].reshape(new_oshape); const TBlob lhs = inputs[1].reshape(new_lshape); const TBlob rhs = inputs[2].reshape(new_rshape); - size_t workspace_size_l = ReduceWorkspaceSize(s, lgrad, req[0], ograd, lhs, rhs); - size_t workspace_size_r = ReduceWorkspaceSize(s, rgrad, req[1], ograd, lhs, rhs); + size_t workspace_size_l = ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 475409e6a774..0c74cac2dca5 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -88,6 +88,7 @@ struct EyeParam : public dmlc::Parameter { .add_enum("float64", mshadow::kFloat64) .add_enum("float16", mshadow::kFloat16) .add_enum("uint8", mshadow::kUint8) + .add_enum("int8", mshadow::kInt8) .add_enum("int32", mshadow::kInt32) .add_enum("int64", mshadow::kInt64) .describe("Target data type."); diff --git a/tests/python/quantization/common.py b/tests/python/quantization/common.py new file mode 120000 index 000000000000..dccb90b10675 --- /dev/null +++ b/tests/python/quantization/common.py @@ -0,0 +1 @@ +../unittest/common.py \ No newline at end of file diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py new file mode 100644 index 000000000000..7b08f46e836a --- /dev/null +++ b/tests/python/quantization/test_quantization.py @@ -0,0 +1,447 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Some of the tests using CUDNN require a special GPU instruction called dp4a. +Ref: http://images.nvidia.com/content/pdf/tesla/184457-Tesla-P4-Datasheet-NV-Final-Letter-Web.pdf +""" +import mxnet as mx +import numpy as np +from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same, DummyIter +from common import with_seed +from mxnet.module import Module +from mxnet.io import NDArrayIter + + +@with_seed() +def test_quantize_float32_to_int8(): + shape = rand_shape_nd(4) + data = rand_ndarray(shape, 'default', dtype='float32') + min_range = mx.nd.min(data) + max_range = mx.nd.max(data) + qdata, min_val, max_val = mx.nd.contrib.quantize(data, min_range, max_range, out_type='int8') + data_np = data.asnumpy() + min_range = min_range.asscalar() + max_range = max_range.asscalar() + real_range = np.maximum(np.abs(min_range), np.abs(max_range)) + quantized_range = 127.0 + scale = quantized_range / real_range + assert qdata.dtype == np.int8 + assert min_val.dtype == np.float32 + assert max_val.dtype == np.float32 + assert same(min_val.asscalar(), -real_range) + assert same(max_val.asscalar(), real_range) + qdata_np = (np.sign(data_np) * np.minimum(np.abs(data_np) * scale + 0.5, quantized_range)).astype(np.int8) + assert same(qdata.asnumpy(), qdata_np) + + +@with_seed() +def test_dequantize_int8_to_float32(): + shape = rand_shape_nd(4) + qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8) + qdata = mx.nd.array(qdata_np, dtype=np.int8) + real_range = 402.3347 + min_range = mx.nd.array([-real_range], dtype=np.float32) + max_range = mx.nd.array([real_range], dtype=np.float32) + data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') + quantized_range = 127.0 + scale = real_range / quantized_range + assert data.dtype == np.float32 + data_np = qdata_np * scale + assert_almost_equal(data.asnumpy(), data_np) + + +@with_seed() +def test_requantize_int32_to_int8(): + def quantized_int32_to_float(qdata, min_range, max_range): + assert qdata.dtype == 'int32' + quantized_range = np.iinfo('int32').max + real_range = np.maximum(np.abs(min_range), np.abs(max_range)) + scale = float(real_range) / float(quantized_range) + return qdata.astype('float32') * scale + + def float_to_quantized_int8(data, min_range, max_range): + assert data.dtype == 'float32' + real_range = np.maximum(np.abs(min_range), np.abs(max_range)) + quantized_range = np.iinfo('int8').max + scale = float(quantized_range) / float(real_range) + return (np.sign(data) * np.minimum(np.abs(data) * scale + 0.5, quantized_range)).astype('int8') + + def requantize(qdata, min_data, max_data, real_range): + data = quantized_int32_to_float(qdata, min_data, max_data) + output = float_to_quantized_int8(data, -real_range, real_range) + return output, -real_range, real_range + + def requantize_baseline(qdata, min_data, max_data, min_calib_range=None, max_calib_range=None): + if min_calib_range is not None and max_calib_range is not None: + real_range = np.maximum(np.abs(min_calib_range), np.abs(max_calib_range)) + return requantize(qdata, min_data, max_data, real_range) + else: + min_range = quantized_int32_to_float(np.min(qdata), min_data, max_data) + max_range = quantized_int32_to_float(np.max(qdata), min_data, max_data) + return requantize(qdata, min_data, max_data, np.maximum(np.abs(min_range), np.abs(max_range))) + + def check_requantize(shape, min_calib_range=None, max_calib_range=None): + qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32') + min_range = mx.nd.array([-1010.0]) + max_range = mx.nd.array([1020.0]) + if min_calib_range is None or max_calib_range is None: + qdata_int8, min_output, max_output = mx.nd.contrib.requantize(qdata, min_range, max_range) + else: + qdata_int8, min_output, max_output = mx.nd.contrib.requantize(qdata, min_range, max_range, + min_calib_range, max_calib_range) + + qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(), + max_range.asscalar(), + min_calib_range=min_calib_range, + max_calib_range=max_calib_range) + assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np) + assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) + assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) + + check_requantize((3, 4, 10, 10)) + check_requantize((32, 3, 23, 23)) + check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0) + check_requantize((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43) + + +@with_seed() +def test_quantized_conv(): + if mx.current_context().device_type != 'gpu': + print('skipped testing quantized_conv on cpu since it is not implemented yet') + return + + def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias): + with mx.Context('gpu', 0): + # run fp32 conv + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride, + no_bias=no_bias, cudnn_off=False, name='conv2d') + arg_shapes, _, _ = conv2d.infer_shape(data=data_shape) + arg_names = conv2d.list_arguments() + conv_exe_fp32 = conv2d.simple_bind(ctx=mx.current_context(), grad_req='null') + conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=data_shape).astype('int32') + conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=arg_shapes[1]).astype('int32') + if not no_bias: + conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=arg_shapes[2]).astype('int32') + output = conv_exe_fp32.forward()[0] + + # run quantized conv + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8') + qweight = mx.sym.Variable(name='qweight', dtype='int8') + min_data = mx.sym.Variable(name='min_data') + max_data = mx.sym.Variable(name='max_data') + min_weight = mx.sym.Variable(name='min_weight') + max_weight = mx.sym.Variable(name='max_weight') + quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data, + max_data=max_data, min_weight=min_weight, + max_weight=max_weight, kernel=kernel, + num_filter=num_filter, pad=pad, stride=stride, + no_bias=no_bias) + qarg_names = quantized_conv2d.list_arguments() + type_dict = None + if not no_bias: + type_dict = {qarg_names[2]: 'int8'} + conv_exe_int8 = quantized_conv2d.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null') + conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype('int8') + conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8') + quantized_range = 127.0 + if no_bias: + conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range + conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range + else: + conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8') + conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range + conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range + conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range + qoutput, min_range, max_range = conv_exe_int8.forward() + + if no_bias: + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + else: + # with adding bias, accuracy loss should not be greater than one + diff = mx.nd.abs(output - qoutput.astype(output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), True) + check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), False) + + +@with_seed() +def test_quantized_pooling(): + if mx.current_context().device_type != 'gpu': + print('skipped testing quantized_pooling on cpu since it is not implemented yet') + return + + def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool): + with mx.Context('gpu', 0): + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride, + pool_type=pool_type, global_pool=global_pool, cudnn_off=False) + arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape) + arg_names = pooling_fp32.list_arguments() + pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') + pooling_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=data_shape).astype('int32') + output = pooling_fp32_exe.forward()[0] + + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8') + min_data = mx.sym.Variable(name='min_data') + max_data = mx.sym.Variable(name='max_data') + quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data, + max_data=max_data, kernel=kernel, + pad=pad, stride=stride, pool_type=pool_type, + global_pool=global_pool) + pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null') + qarg_names = quantized_pooling.list_arguments() + pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype('int8') + quantized_range = 127.0 + pooling_int8_exe.arg_dict[qarg_names[1]][:] = -quantized_range + pooling_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range + qoutput, min_range, max_range = pooling_int8_exe.forward() + + if pool_type == 'max': + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + elif pool_type == 'avg': # for avg pooling, fp32 and int8 may be different due to rounding errors + diff = mx.nd.abs(output - qoutput.astype(output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False) + check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True) + check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False) + check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True) + + +@with_seed() +def test_quantized_fc(): + if mx.current_context().device_type != 'gpu': + print('skipped testing quantized_fc on cpu since it is not implemented yet') + return + + def check_quantized_fc(data_shape, num_hidden, no_bias, flatten=True): + with mx.Context('gpu', 0): + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten) + arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape) + arg_names = fc_fp32.list_arguments() + fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') + fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=data_shape).astype('int32') + fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=arg_shapes[1]).astype('int32') + if not no_bias: + fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=arg_shapes[2]).astype('int32') + output = fc_fp32_exe.forward()[0] + + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8') + fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, num_hidden=num_hidden, + no_bias=no_bias, flatten=flatten) + qarg_names = fc_int8.list_arguments() + type_dict = {qarg_names[1]: 'int8'} + if not no_bias: + type_dict.update({qarg_names[2]: 'int8'}) + fc_int8_exe = fc_int8.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null') + fc_int8_exe.arg_dict[qarg_names[0]][:] = fc_fp32_exe.arg_dict[arg_names[0]].astype('int8') + fc_int8_exe.arg_dict[qarg_names[1]][:] = fc_fp32_exe.arg_dict[arg_names[1]].astype('int8') + quantized_range = 127.0 + if no_bias: + fc_int8_exe.arg_dict[qarg_names[2]][:] = -quantized_range + fc_int8_exe.arg_dict[qarg_names[3]][:] = quantized_range + fc_int8_exe.arg_dict[qarg_names[4]][:] = -quantized_range + fc_int8_exe.arg_dict[qarg_names[5]][:] = quantized_range + else: + fc_int8_exe.arg_dict[qarg_names[2]][:] = fc_fp32_exe.arg_dict[arg_names[2]].astype('int8') + fc_int8_exe.arg_dict[qarg_names[3]][:] = -quantized_range + fc_int8_exe.arg_dict[qarg_names[4]][:] = quantized_range + fc_int8_exe.arg_dict[qarg_names[5]][:] = -quantized_range + fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range + fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range + fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range + qoutput, min_range, max_range = fc_int8_exe.forward() + + if no_bias: + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + else: + # with adding bias, accuracy loss should not be greater than one + diff = mx.nd.abs(output - qoutput.astype(output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + check_quantized_fc((32, 512, 2, 2), 100, True) + check_quantized_fc((32, 111, 2, 2), 100, True) + check_quantized_fc((32, 512, 2, 2), 100, False) + check_quantized_fc((32, 111, 2, 2), 100, False) + + +@with_seed() +def test_quantized_flatten(): + def check_quantized_flatten(shape): + qdata = mx.nd.random.uniform(low=-127, high=127, shape=shape).astype('int8') + min_data = mx.nd.array([-1023.343], dtype='float32') + max_data = mx.nd.array([2343.324275], dtype='float32') + qoutput, min_output, max_output = mx.nd.contrib.quantized_flatten(qdata, min_data, max_data) + assert qoutput.ndim == 2 + assert qoutput.shape[0] == qdata.shape[0] + assert qoutput.shape[1] == np.prod(qdata.shape[1:]) + assert same(qdata.asnumpy().flatten(), qoutput.asnumpy().flatten()) + assert same(min_data.asnumpy(), min_output.asnumpy()) + assert same(max_data.asnumpy(), max_output.asnumpy()) + + check_quantized_flatten((10,)) + check_quantized_flatten((10, 15)) + check_quantized_flatten((10, 15, 18)) + check_quantized_flatten((3, 4, 23, 23)) + + +@with_seed() +def test_quantize_params(): + data = mx.sym.Variable('data') + conv = mx.sym.Convolution(data, kernel=(1, 1), num_filter=2048, name='conv') + sym = mx.sym.BatchNorm(data=conv, eps=2e-05, fix_gamma=False, momentum=0.9, use_global_stats=False, name='bn') + offline_params = [name for name in sym.list_arguments() + if not name.startswith('data') and not name.endswith('label')] + params = {} + for name in offline_params: + params[name] = mx.nd.uniform(shape=(2, 2)) + qsym = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params) + qparams = mx.contrib.quant._quantize_params(qsym, params) + param_names = params.keys() + qparam_names = qparams.keys() + for name in qparam_names: + if name.startswith('bn'): + assert name in param_names + elif name.startswith('conv'): + assert name not in param_names + assert name.find('quantize') != -1 + + +def get_fp32_sym(): + data = mx.sym.Variable('data') + conv = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv') + bn = mx.sym.BatchNorm(data=conv, eps=2e-05, fix_gamma=False, momentum=0.9, use_global_stats=False, name='bn') + act = mx.sym.Activation(data=bn, act_type='relu', name='relu') + pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool') + fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False, + out_grad=False, preserve_shape=False, use_ignore=False, name='softmax') + return sym + + +@with_seed() +def test_quantize_model(): + def check_params(params, qparams, qsym=None): + if qsym is None: + assert len(params) == len(qparams) + for k, v in params.items(): + assert k in qparams + assert same(v.asnumpy(), qparams[k].asnumpy()) + else: + qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params) + assert len(qparams) == len(qparams_ground_truth) + for k, v in qparams_ground_truth.items(): + assert k in qparams + assert same(v.asnumpy(), qparams[k].asnumpy()) + + def check_qsym_calibrated(qsym): + attrs = qsym.attr_dict() + for k, v in attrs.items(): + if k.find('requantize_') != -1: + assert 'min_calib_range' in v + assert 'max_calib_range' in v + + sym = get_fp32_sym() + mod = Module(symbol=sym) + batch_size = 4 + data_shape = (batch_size, 4, 10, 10) + label_shape = (batch_size, 10) + mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) + mod.init_params() + arg_params, aux_params = mod.get_params() + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + calib_mode='none') + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + + calib_data = mx.nd.random.uniform(shape=data_shape) + calib_data = NDArrayIter(data=calib_data) + calib_data = DummyIter(calib_data) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + calib_mode='naive', + calib_data=calib_data, + num_calib_examples=20) + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + check_qsym_calibrated(qsym) + + +@with_seed() +def test_quantize_sym_with_calib(): + sym = get_fp32_sym() + offline_params = [name for name in sym.list_arguments() + if not name.startswith('data') and not name.endswith('label')] + qsym = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params) + requantize_op_names = ['requantize_conv', 'requantize_fc'] + th_dict = {'conv_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0)), + 'fc_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0))} + op_name_to_th_name = {'requantize_conv': 'conv_output', 'requantize_fc': 'fc_output'} + cqsym = mx.contrib.quant._calibrate_quantized_sym(qsym, th_dict) + attr_dict = cqsym.attr_dict() + for name in requantize_op_names: + assert name in attr_dict + lhs = float(attr_dict[name]['min_calib_range']) + rhs = th_dict[op_name_to_th_name[name]][0] + assert_almost_equal(np.array([lhs]), np.array([rhs])) + lhs = float(attr_dict[name]['max_calib_range']) + rhs = th_dict[op_name_to_th_name[name]][1] + assert_almost_equal(np.array([lhs]), np.array([rhs]), rtol=1e-3, atol=1e-4) + + +@with_seed() +def test_get_optimal_thresholds(): + # Given an ndarray with elements following a uniform distribution, the optimal threshold + # for quantizing the ndarray should be either abs(min(nd)) or abs(max(nd)). + def get_threshold(nd): + min_nd = mx.nd.min(nd) + max_nd = mx.nd.max(nd) + return mx.nd.maximum(mx.nd.abs(min_nd), mx.nd.abs(max_nd)).asnumpy() + + nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23))} + expected_threshold = get_threshold(nd_dict['layer1']) + th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict) + assert 'layer1' in th_dict + assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=0.001, atol=0.001) + + +if __name__ == "__main__": + import nose + nose.runmodule() diff --git a/tests/python/quantization_gpu/test_quantization_gpu.py b/tests/python/quantization_gpu/test_quantization_gpu.py new file mode 100644 index 000000000000..4f2d70effd49 --- /dev/null +++ b/tests/python/quantization_gpu/test_quantization_gpu.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import sys +import mxnet as mx + + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../quantization')) +from mxnet.test_utils import set_default_context +from test_quantization import * + +set_default_context(mx.gpu(0)) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 240c06a5d7a2..2486be04a52f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2539,7 +2539,7 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5): for req in ['write', 'add']: check_numeric_gradient(out_s, {'data': data, 'gamma': gamma, 'beta': beta}, grad_nodes={'data': req, 'gamma': req, 'beta': req}, - numeric_eps=1e-2, rtol=1e-2, atol=1e-3) + numeric_eps=1e-2, rtol=1e-2, atol=1e-2) def test_layer_norm(): for dtype, forward_check_eps in zip([np.float16, np.float32, np.float64],