Skip to content

Commit

Permalink
[MXNET-133] Model Quantization with Calibration (apache#9552)
Browse files Browse the repository at this point in the history
* [Quantization] 8bit Quantization and GPU Support

[Quantization] CuDNN 8bit quantized relu v0.1

[Quantization] CuDNN 8bit quantized max_pool v0.1

[Quantization] CuDNN 8bit quantized lrn v0.1

[Quantization] CuDNN 8bit quantized convolution v0.1

[Quantization] CuDNN 8bit quantized fully connected v0.1

[Quantization] Small fix

[Quantization] Implement backward method

[Quantization] Convolution backward method

[Quantization] Add range for matmul and conv

[Quantization] New types in ndarray.py

[Quantization] 8bit conv works

[Quantization] conv support multiple type

[Quantization] matmul works now

[Quantization] matmul works well

[Quantization] efactor quantization operators

[Quantization] Op: quantize_down_and_shrink_range

[Quantization] Complete quantize_graph_pass

[Quantization] Add example

[Quantization] Take zero-center quantize, accuracy fixed

[Quantization] Multiple layers MLP pass

[Quantization] Make quantized_conv same as Convolution

[Quantization] quantized_conv works

[Quantization] Fix bug

[Quantization] lenet works now

[Quantization] Add quantized_flatten

[Quantization] Quantized max pool works well

[Quantization] Make quantized_conv support NHWC

[Quantization] add max_pool

[Quantization] add ignore_symbols

[Quantization] Save change

[Quantization] Reorganize tests, 8 layers resnet works on cifar

[Quantization] Support for 'NHWC' max pool

[Quantization] Support for 'NHWC' quantized max pool

[Quantization] Fix speed of quantize_down_and_shrink_range

[Quantization] script for resnet on imagenet

[Quantization] refactor for quantize offline

[Quantization] Fix infershape

[Quantization] Update test

[Quantization] Update example

[Quantization] Fix build error

* [Quantization] Add calibration flow and refactor code

Rebase with dmlc/master

Add quantize_down_and_shrink by threshold

Don't assign resource when threshold is available for quantize_down_and_shrink

Fix quantize_down_and_shrink saturation

Implement pass for setting calib table to node attrs

Rebase with upstream master

Change threshold to min/max quantized params

Add c-api for setting calib table to graph

Add calibration front end function

Bug fixes and add unit test

Add data iter type to calibration

Fix bug in calibrate_quantized_model

Bug fix and add example

Add the second calibration approach and benchmark

Fix

Fix infer error and add benchmark for conv

Add benchmark script

Change output names and argument names

Remove commented out code

Change name

Add layout to benchmark_convolution

Remove redundant comment

Remove common and add soft link

More fix and benchmark

Add scripts to plot images

Minor fix

More fix

More fix and util tools

Tools and support bias in quantized_conv2d

Add script for getting the optimal thresholds using kl divergence

Add kl divergence for optimizing thresholds

Add benchmark scripts

Fix compile after rebasing on master

Allocate temp space only once for quantized_conv2d

Change quantize_down_and_shrink_range to allocate temp space once

No temp space for calib model

Refactor quantize_down_and_shrink_range into requantize

Refactor quantized convolution using nnvm interfaces

Fix quantized_conv bug

Use ConvolutionParam for QuantizedCuDNNConvOp

Refactor quantized fc using nnvm interfaces

Change TQuantizationNeedShrink to FNeedRequantize

Refactor quantized_pooling

Simplify FQuantizedOp interface

Better naming

Fix shape and type inference for quantized_flatten

Clean up quantization frontend APIs and examples

Delete quantized lrn and relu

Add python script for generating quantized models

Add script for running inference

Add inference example

Remove redundant files from example/quantization

Simplify user-level python APIs

Add logger

Improve user-level python api

Fix coding style

Add unit test for quantized_conv

Fix bugs in quantized_fully_connected and add unit test

Add unit test for requantize

Fix a bug and add python api unit tests

Import test_quantization in test_operator_gpu.py

Rebase with master

Remove redundant files

Fix test case for python3 and fix doc

Fix unit tests

Fix unit tests for python3

Release used ndarrays in calibration for saving memory usage

Simplify releasing memory of used ndarrays for calibration

Fix a bug

Revert "Fix a bug"

This reverts commit f7853f2.

Revert "Simplify releasing memory of used ndarrays for calibration"

This reverts commit 70b9e38.

Clean up benchmark script and improve example

Add API and example documentation and fix bugs

Remove redundant test file and improve error message

Merge quantize and dequantize with master impl

Remove commented code

Hide monitor interface from users

Remove interface from Module

Add license header

Move quantization unittests to a separate folder so that it can be only run on P3 instances

Remove quantization unittests from test_operator_gpu.py

Move quantization to contrib

Fix lint

Add mxnetlinux-gpu-p3 to jenkins

Fix jenkins

Fix CI build

Fix CI

Update jenkins file

Use cudnn7 for ci

Add docker file for quantization unit test only

Correctly skip build with cudnn < 6

Add doc for quantize symbol api

Fix lint

Fix python3 and add doc

Try to fix cudnn build problem

* Fix compile error

* Fix CI

* Remove tests that should not run on P3

* Remove unnecessary docker file

* Fix registering quantized nn ops

* Reformat Jenkinsfile and switch quantization to CUDA 9 (#9)

* Address interface change cr

* Address comments and fix bugs

* Make unit test stable

* Improve unit test

* Address cr

* Address cr

* Fix flaky unit test layer_norm

* Fix doc
  • Loading branch information
reminisce authored and zheng-da committed Jun 28, 2018
1 parent 6c1896b commit 7a069ea
Show file tree
Hide file tree
Showing 49 changed files with 3,883 additions and 116 deletions.
18 changes: 18 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,24 @@ try {
}
}
},
'Python2: Quantize GPU': {
node('mxnetlinux-gpu-p3') {
ws('workspace/ut-python2-quantize-gpu') {
init_git()
unpack_lib('gpu', mx_lib)
sh "ci/build.py --nvidiadocker --build --platform ubuntu_gpu /work/runtime_functions.sh unittest_ubuntu_python2_quantization_gpu"
}
}
},
'Python3: Quantize GPU': {
node('mxnetlinux-gpu-p3') {
ws('workspace/ut-python3-quantize-gpu') {
init_git()
unpack_lib('gpu', mx_lib)
sh "ci/build.py --nvidiadocker --build --platform ubuntu_gpu /work/runtime_functions.sh unittest_ubuntu_python3_quantization_gpu"
}
}
},
'Python2: MKLDNN-CPU': {
node('mxnetlinux-cpu') {
ws('workspace/ut-python2-mkldnn-cpu') {
Expand Down
90 changes: 90 additions & 0 deletions benchmark/python/quantization/benchmark_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import mxnet as mx
from mxnet.test_utils import check_speed


def quantize_int8_helper(data):
min_data = mx.nd.min(data)
max_data = mx.nd.max(data)
return mx.nd.contrib.quantize(data, min_data, max_data, out_type='int8')


def benchmark_convolution(data_shape, kernel, num_filter, pad, stride, no_bias=True, layout='NCHW', repeats=20):
ctx_gpu = mx.gpu(0)
data = mx.sym.Variable(name="data", shape=data_shape, dtype='float32')
# conv cudnn
conv_cudnn = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
no_bias=no_bias, layout=layout, cudnn_off=False, name="conv_cudnn")
arg_shapes, _, _ = conv_cudnn.infer_shape(data=data_shape)
input_data = mx.nd.random.normal(0, 0.2, shape=data_shape, ctx=ctx_gpu)
conv_weight_name = conv_cudnn.list_arguments()[1]
args = {data.name: input_data, conv_weight_name: mx.random.normal(0, 1, shape=arg_shapes[1], ctx=ctx_gpu)}
conv_cudnn_time = check_speed(sym=conv_cudnn, location=args, ctx=ctx_gpu, N=repeats,
grad_req='null', typ='forward') * 1000

# quantized_conv2d
qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
weight = mx.sym.Variable(name='weight', shape=arg_shapes[1], dtype='int8')
min_data = mx.sym.Variable(name='min_data', shape=(1,), dtype='float32')
max_data = mx.sym.Variable(name='max_data', shape=(1,), dtype='float32')
min_weight = mx.sym.Variable(name='min_weight', shape=(1,), dtype='float32')
max_weight = mx.sym.Variable(name='max_weight', shape=(1,), dtype='float32')
quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=weight, min_data=min_data, max_data=max_data,
min_weight=min_weight, max_weight=max_weight,
kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
no_bias=no_bias, layout=layout, cudnn_off=False,
name='quantized_conv2d')
qargs = {qdata.name: quantize_int8_helper(input_data)[0],
min_data.name: quantize_int8_helper(input_data)[1],
max_data.name: quantize_int8_helper(input_data)[2],
weight.name: quantize_int8_helper(args[conv_weight_name])[0],
min_weight.name: quantize_int8_helper(args[conv_weight_name])[1],
max_weight.name: quantize_int8_helper(args[conv_weight_name])[2]}
qconv_time = check_speed(sym=quantized_conv2d, location=qargs, ctx=ctx_gpu, N=repeats,
grad_req='null', typ='forward') * 1000

print('==================================================================================================')
print('data=%s, kernel=%s, num_filter=%s, pad=%s, stride=%s, no_bias=%s, layout=%s, repeats=%s'
% (data_shape, kernel, num_filter, pad, stride, no_bias, layout, repeats))
print('%s , ctx=%s, time=%.2f ms' % (conv_cudnn.name + '-FP32', ctx_gpu, conv_cudnn_time))
print('%s, ctx=%s, time=%.2f ms' % (quantized_conv2d.name, ctx_gpu, qconv_time))
print('quantization speedup: %.1fX' % (conv_cudnn_time / qconv_time))
print('\n')


if __name__ == '__main__':
for batch_size in [32, 64, 128]:
benchmark_convolution(data_shape=(batch_size, 64, 56, 56), kernel=(1, 1), num_filter=256,
pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20)

benchmark_convolution(data_shape=(batch_size, 256, 56, 56), kernel=(1, 1), num_filter=64,
pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20)

benchmark_convolution(data_shape=(batch_size, 256, 56, 56), kernel=(1, 1), num_filter=128,
pad=(0, 0), stride=(2, 2), layout='NCHW', repeats=20)

benchmark_convolution(data_shape=(batch_size, 128, 28, 28), kernel=(3, 3), num_filter=128,
pad=(1, 1), stride=(1, 1), layout='NCHW', repeats=20)

benchmark_convolution(data_shape=(batch_size, 1024, 14, 14), kernel=(1, 1), num_filter=256,
pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20)

benchmark_convolution(data_shape=(batch_size, 2048, 7, 7), kernel=(1, 1), num_filter=512,
pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20)
26 changes: 26 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ unittest_ubuntu_python2_cpu() {
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
nosetests-2.7 --verbose tests/python/unittest
nosetests-2.7 --verbose tests/python/train
nosetests-2.7 --verbose tests/python/quantization
}

unittest_ubuntu_python3_cpu() {
Expand All @@ -371,6 +372,7 @@ unittest_ubuntu_python3_cpu() {
#export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
nosetests-3.4 --verbose tests/python/unittest
nosetests-3.4 --verbose tests/python/quantization
}

unittest_ubuntu_python2_gpu() {
Expand All @@ -393,6 +395,30 @@ unittest_ubuntu_python3_gpu() {
nosetests-3.4 --verbose tests/python/gpu
}

# quantization gpu currently only runs on P3 instances
# need to separte it from unittest_ubuntu_python2_gpu()
unittest_ubuntu_python2_quantization_gpu() {
set -ex
export PYTHONPATH=./python/
# MXNET_MKLDNN_DEBUG is buggy and produces false positives
# https://github.com/apache/incubator-mxnet/issues/10026
#export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
nosetests-2.7 --verbose tests/python/quantization_gpu
}

# quantization gpu currently only runs on P3 instances
# need to separte it from unittest_ubuntu_python3_gpu()
unittest_ubuntu_python3_quantization_gpu() {
set -ex
export PYTHONPATH=./python/
# MXNET_MKLDNN_DEBUG is buggy and produces false positives
# https://github.com/apache/incubator-mxnet/issues/10026
#export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
nosetests-3.4 --verbose tests/python/quantization_gpu
}

unittest_ubuntu_cpu_scala() {
set -ex
make scalapkg USE_BLAS=openblas
Expand Down
22 changes: 22 additions & 0 deletions example/quantization/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Model Quantization with Calibration Examples
This folder contains examples of quantizing a FP32 model with or without calibration and using the calibrated
quantized for inference. Two pre-trained imagenet models are taken as examples for quantization. One is
[Resnet-152](http://data.mxnet.io/models/imagenet/resnet/152-layers/), and the other one is
[Inception with BatchNorm](http://data.mxnet.io/models/imagenet/inception-bn/). The calibration dataset
is the [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) for testing the pre-trained models.

Here are the details of the four files in this folder.
- `imagenet_gen_qsym.py` This script provides an example of taking FP32 models and calibration dataset to generate
calibrated quantized models. When launched for the first time, the script would download the user-specified model,
either Resnet-152 or Inception,
and calibration dataset into `model` and `data` folders, respectively. The generated quantized models can be found in
the `model` folder.
- `imagenet_inference.py` This script is used for calculating the accuracy of FP32 models or quantized models on the
validation dataset which was downloaded for calibration in `imagenet_gen_qsym.py`.
- `launch_quantize.sh` This is a shell script that generates various quantized models for Resnet-152 and
Inception with BatchNorm with different configurations. Users can copy and paste the command from the script to
the console to run model quantization for a specific configuration.
- `launch_inference.sh` This is a shell script that calculate the accuracies of all the quantized models generated
by invoking `launch_quantize.sh`.

**NOTE**: This example has only been tested on Linux systems.
1 change: 1 addition & 0 deletions example/quantization/common
194 changes: 194 additions & 0 deletions example/quantization/imagenet_gen_qsym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import os
import logging
from common import modelzoo
import mxnet as mx
from mxnet.contrib.quantization import *


def download_calib_dataset(dataset_url, calib_dataset, logger=None):
if logger is not None:
logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset))
mx.test_utils.download(dataset_url, calib_dataset)


def download_model(model_name, logger=None):
dir_path = os.path.dirname(os.path.realpath(__file__))
model_path = os.path.join(dir_path, 'model')
if logger is not None:
logger.info('Downloading model %s... into path %s' % (model_name, model_path))
return modelzoo.download_model(args.model, os.path.join(dir_path, 'model'))


def save_symbol(fname, sym, logger=None):
if logger is not None:
logger.info('Saving symbol into file at %s' % fname)
sym.save(fname)


def save_params(fname, arg_params, aux_params, logger=None):
if logger is not None:
logger.info('Saving params into file at %s' % fname)
save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()}
save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()})
mx.nd.save(fname, save_dict)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model')
parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--label-name', type=str, default='softmax_label')
parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec',
help='path of the calibration dataset')
parser.add_argument('--image-shape', type=str, default='3,224,224')
parser.add_argument('--data-nthreads', type=int, default=60,
help='number of threads for data decoding')
parser.add_argument('--num-calib-batches', type=int, default=10,
help='number of batches for calibration')
parser.add_argument('--exclude-first-conv', action='store_true', default=True,
help='excluding quantizing the first conv layer since the'
' number of channels is usually not a multiple of 4 in that layer'
' which does not satisfy the requirement of cuDNN')
parser.add_argument('--shuffle-dataset', action='store_true', default=True,
help='shuffle the calibration dataset')
parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
help='shuffling chunk seed, see'
' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
' for more details')
parser.add_argument('--shuffle-seed', type=int, default=48564309,
help='shuffling seed, see'
' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
' for more details')
parser.add_argument('--calib-mode', type=str, default='entropy',
help='calibration mode used for generating calibration table for the quantized symbol; supports'
' 1. none: no calibration will be used. The thresholds for quantization will be calculated'
' on the fly. This will result in inference speed slowdown and loss of accuracy'
' in general.'
' 2. naive: simply take min and max values of layer outputs as thresholds for'
' quantization. In general, the inference accuracy worsens with more examples used in'
' calibration. It is recommended to use `entropy` mode as it produces more accurate'
' inference results.'
' 3. entropy: calculate KL divergence of the fp32 output and quantized output for optimal'
' thresholds. This mode is expected to produce the best inference accuracy of all three'
' kinds of quantized models if the calibration dataset is representative enough of the'
' inference dataset.')
args = parser.parse_args()

logging.basicConfig()
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)

logger.info('shuffle_dataset=%s' % args.shuffle_dataset)

calib_mode = args.calib_mode
logger.info('calibration mode set to %s' % calib_mode)

# download calibration dataset
if calib_mode != 'none':
download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset)

# download model
prefix, epoch = download_model(model_name=args.model, logger=logger)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

# get batch size
batch_size = args.batch_size
logger.info('batch size = %d for calibration' % batch_size)

# get number of batches for calibration
num_calib_batches = args.num_calib_batches
if calib_mode != 'none':
logger.info('number of batches = %d for calibration' % num_calib_batches)

# get number of threads for decoding the dataset
data_nthreads = args.data_nthreads

# get image shape
image_shape = args.image_shape

exclude_first_conv = args.exclude_first_conv
excluded_sym_names = []
if args.model == 'imagenet1k-resnet-152':
rgb_mean = '0,0,0'
calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
or name.find('sc') != -1
or name.find('fc') != -1)
if exclude_first_conv:
excluded_sym_names = ['conv0']
elif args.model == 'imagenet1k-inception-bn':
rgb_mean = '123.68,116.779,103.939'
calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
or name.find('fc') != -1)
if exclude_first_conv:
excluded_sym_names = ['conv_1']
else:
raise ValueError('model %s is not supported in this script' % args.model)

label_name = args.label_name
logger.info('label_name = %s' % label_name)

data_shape = tuple([int(i) for i in image_shape.split(',')])
logger.info('Input data shape = %s' % str(data_shape))

logger.info('rgb_mean = %s' % rgb_mean)
rgb_mean = [float(i) for i in rgb_mean.split(',')]
mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}

if calib_mode == 'none':
logger.info('Quantizing FP32 model %s' % args.model)
qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
excluded_sym_names=excluded_sym_names,
calib_mode=calib_mode, logger=logger)
sym_name = '%s-symbol.json' % (prefix + '-quantized')
save_symbol(sym_name, qsym, logger)
else:
logger.info('Creating ImageRecordIter for reading calibration dataset')
data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset,
label_width=1,
preprocess_threads=data_nthreads,
batch_size=batch_size,
data_shape=data_shape,
label_name=label_name,
rand_crop=False,
rand_mirror=False,
shuffle=args.shuffle_dataset,
shuffle_chunk_seed=args.shuffle_chunk_seed,
seed=args.shuffle_seed,
**mean_args)

cqsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
ctx=mx.gpu(0), excluded_sym_names=excluded_sym_names,
calib_mode=calib_mode, calib_data=data,
num_calib_examples=num_calib_batches * batch_size,
calib_layer=calib_layer, logger=logger)
if calib_mode == 'entropy':
suffix = '-quantized-%dbatches-entropy' % num_calib_batches
elif calib_mode == 'naive':
suffix = '-quantized-%dbatches-naive' % num_calib_batches
else:
raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`'
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
save_symbol(sym_name, cqsym, logger)

param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
Loading

0 comments on commit 7a069ea

Please sign in to comment.