Skip to content

Commit

Permalink
Add calibration flow for quantization (#4)
Browse files Browse the repository at this point in the history
* [Quantization] Add calibration flow

Rebase with dmlc/master

Add quantize_down_and_shrink by threshold

Don't assign resource when threshold is available for quantize_down_and_shrink

Fix quantize_down_and_shrink saturation

Implement pass for setting calib table to node attrs

Rebase with upstream master

Change threshold to min/max quantized params

Add c-api for setting calib table to graph

Add calibration front end function

Bug fixes and add unit test

Add data iter type to calibration

Fix bug in calibrate_quantized_model

Bug fix and add example

Add the second calibration approach and benchmark

Fix

Fix infer error and add benchmark for conv

Add benchmark script

* Change output names and argument names

* Remove commented out code

* Change name

* Add layout to benchmark_convolution

* Remove redundant comment

* Remove common and add soft link
  • Loading branch information
reminisce authored and ZihengJiang committed Oct 5, 2017
1 parent 8cebd6a commit 8d82b2b
Show file tree
Hide file tree
Showing 26 changed files with 706 additions and 82 deletions.
77 changes: 77 additions & 0 deletions benchmark/python/quantization/benchmark_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import mxnet as mx
from mxnet.test_utils import check_speed


def benchmark_convolution(data_shape, kernel, num_filter, pad, stride, no_bias=True, layout='NCHW', repeats=20):
ctx_gpu = mx.gpu(0)
data = mx.sym.Variable(name="data", shape=data_shape)
# conv cudnn
conv_cudnn = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
no_bias=no_bias, layout=layout, cudnn_off=False, name="conv_cudnn")
arg_shapes, _, _ = conv_cudnn.infer_shape(data=data_shape)
input_data = mx.nd.random.normal(0, 0.2, shape=data_shape, ctx=ctx_gpu)
conv_weight_name = conv_cudnn.list_arguments()[1]
args = {data.name: input_data, conv_weight_name: mx.random.normal(0, 1, shape=arg_shapes[1], ctx=ctx_gpu)}
conv_cudnn_time = check_speed(sym=conv_cudnn, location=args, ctx=ctx_gpu, N=repeats,
grad_req='null', typ='forward') * 1000

# quantized_conv2d
qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
weight = mx.sym.Variable(name='weight', shape=arg_shapes[1], dtype='int8')
min_data = mx.sym.Variable(name='min_data', shape=(1,), dtype='float32')
max_data = mx.sym.Variable(name='max_data', shape=(1,), dtype='float32')
min_weight = mx.sym.Variable(name='min_weight', shape=(1,), dtype='float32')
max_weight = mx.sym.Variable(name='max_weight', shape=(1,), dtype='float32')
quantized_conv2d = mx.sym.quantized_conv2d(data=qdata, weight=weight, min_data=min_data, max_data=max_data,
min_weight=min_weight, max_weight=max_weight,
kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
no_bias=no_bias, layout=layout, cudnn_off=False, name='quantized_conv2d')
qargs = {qdata.name: mx.quantization.quantize(input_data)[0],
min_data.name: mx.quantization.quantize(input_data)[1],
max_data.name: mx.quantization.quantize(input_data)[2],
weight.name: mx.quantization.quantize(args[conv_weight_name])[0],
min_weight.name: mx.quantization.quantize(args[conv_weight_name])[1],
max_weight.name: mx.quantization.quantize(args[conv_weight_name])[2]}
qconv_time = check_speed(sym=quantized_conv2d, location=qargs, ctx=ctx_gpu, N=repeats,
grad_req='null', typ='forward') * 1000

print('==================================================================================================')
print('data=%s, kernel=%s, num_filter=%s, pad=%s, stride=%s, no_bias=%s, layout=%s, repeats=%s'
% (data_shape, kernel, num_filter, pad, stride, no_bias, layout, repeats))
print('%s , ctx=%s, time=%.2f ms' % (conv_cudnn.name+'-FP32', ctx_gpu, conv_cudnn_time))
print('%s, ctx=%s, time=%.2f ms' % (quantized_conv2d.name, ctx_gpu, qconv_time))
print('quantization speedup: %.1fX' % (conv_cudnn_time / qconv_time))
print('\n')


if __name__ == '__main__':
for batch_size in [32, 64, 128]:
benchmark_convolution(data_shape=(batch_size, 64, 122, 122), kernel=(7, 7), num_filter=64,
pad=(1, 1), stride=(1, 1), layout='NCHW', repeats=20)
benchmark_convolution(data_shape=(batch_size, 122, 122, 64), kernel=(7, 7), num_filter=64,
pad=(1, 1), stride=(1, 1), layout='NHWC', repeats=20)

benchmark_convolution(data_shape=(batch_size, 64, 56, 56), kernel=(3, 3), num_filter=64,
pad=(3, 3), stride=(2, 2), layout='NCHW', repeats=20)
benchmark_convolution(data_shape=(batch_size, 56, 56, 64), kernel=(3, 3), num_filter=64,
pad=(3, 3), stride=(2, 2), layout='NHWC', repeats=20)

benchmark_convolution(data_shape=(batch_size, 256, 56, 56), kernel=(1, 1), num_filter=256,
pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20)
benchmark_convolution(data_shape=(batch_size, 56, 56, 256), kernel=(1, 1), num_filter=256,
pad=(0, 0), stride=(1, 1), layout='NHWC', repeats=20)

benchmark_convolution(data_shape=(batch_size, 128, 28, 28), kernel=(3, 3), num_filter=128,
pad=(1, 1), stride=(1, 1), layout='NCHW', repeats=20)
benchmark_convolution(data_shape=(batch_size, 28, 28, 128), kernel=(3, 3), num_filter=128,
pad=(1, 1), stride=(1, 1), layout='NHWC', repeats=20)

benchmark_convolution(data_shape=(batch_size, 256, 14, 14), kernel=(3, 3), num_filter=256,
pad=(1, 1), stride=(1, 1), layout='NCHW', repeats=20)
benchmark_convolution(data_shape=(batch_size, 14, 14, 256), kernel=(3, 3), num_filter=256,
pad=(1, 1), stride=(1, 1), layout='NHWC', repeats=20)

benchmark_convolution(data_shape=(batch_size, 2048, 7, 7), kernel=(1, 1), num_filter=2048,
pad=(0, 0), stride=(1, 1), layout='NCHW', repeats=20)
benchmark_convolution(data_shape=(batch_size, 7, 7, 2048), kernel=(1, 1), num_filter=2048,
pad=(0, 0), stride=(1, 1), layout='NHWC', repeats=20)
1 change: 1 addition & 0 deletions example/quantization/common
2 changes: 2 additions & 0 deletions example/quantization/launch_resnet_calib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#! /bin/sh
python resnet_calib.py --model=imagenet1k-resnet-152 --data-val=./data/val-5k-256.rec --gpus=0 --data-nthreads=60
229 changes: 229 additions & 0 deletions example/quantization/resnet_calib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import argparse
from common import modelzoo
import mxnet as mx
import time
import os
import logging
from mxnet.quantization import *


parser = argparse.ArgumentParser(description='score a model on a dataset')
parser.add_argument('--model', type=str, required=True,
help = 'the model name.')
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--rgb-mean', type=str, default='0,0,0')
parser.add_argument('--data-val', type=str, required=True)
parser.add_argument('--image-shape', type=str, default='3,224,224')
parser.add_argument('--data-nthreads', type=int, default=4,
help='number of threads for data decoding')
parser.add_argument('--low-quantile', type=float, default=0)
parser.add_argument('--high-quantile', type=float, default=1)
args = parser.parse_args()

batch_size = args.batch_size
low_quantile = args.low_quantile
high_quantile = args.high_quantile

# number of predicted and calibrated images can be changed
num_predicted_images = batch_size * 2
num_calibrated_images = batch_size * 1

data_nthreads = args.data_nthreads
data_val = args.data_val
gpus = args.gpus
image_shape = args.image_shape
model = args.model
rgb_mean = args.rgb_mean

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


mean_img = None
label_name = 'softmax_label'


# create data iterator
data_shape = tuple([int(i) for i in image_shape.split(',')])
if mean_img is not None:
mean_args = {'mean_img':mean_img}
elif rgb_mean is not None:
rgb_mean = [float(i) for i in rgb_mean.split(',')]
mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b':rgb_mean[2]}

data = mx.io.ImageRecordIter(
path_imgrec = data_val,
label_width = 1,
preprocess_threads = data_nthreads,
batch_size = batch_size,
data_shape = data_shape,
label_name = label_name,
rand_crop = False,
rand_mirror = False,
**mean_args)


if isinstance(model, str):
# download model
dir_path = os.path.dirname(os.path.realpath(__file__))
(prefix, epoch) = modelzoo.download_model(
model, os.path.join(dir_path, 'model'))
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
elif isinstance(model, tuple) or isinstance(model, list):
assert len(model) == 3
(sym, arg_params, aux_params) = model
else:
raise TypeError('model type [%s] is not supported' % str(type(model)))

# create module
if gpus == '':
devs = mx.cpu()
else:
devs = [mx.gpu(int(i)) for i in gpus.split(',')]


def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples):
metrics = [mx.metric.create('acc'),
mx.metric.create('top_k_accuracy', top_k = 5)]
if not isinstance(metrics, list):
metrics = [metrics,]
mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name,])
mod.bind(for_training=False,
data_shapes=data.provide_data,
label_shapes=data.provide_label)
mod.set_params(arg_params, aux_params)

tic = time.time()
num = 0
for batch in data:
mod.forward(batch, is_train=False)
for m in metrics:
mod.update_metric(m, batch.label)
num += batch_size
if max_num_examples is not None and num >= max_num_examples:
break

speed = num / (time.time() - tic)

logging.info('Finished with %f images per second', speed)
for m in metrics:
logging.info(m.get())


def advance_data_iter(data_iter, n):
assert n >= 0
if n == 0:
return data_iter
has_next_batch = True
while has_next_batch:
try:
data_iter.next()
n -= 1
if n == 0:
return data_iter
except StopIteration:
has_next_batch = False


print('\n\n')

#################################################################################################
print('====================================================================\n')
print('Running FP32 model for inference...')
data.reset()
# make sure that fp32 inference works on the same images as calibrated quantized model
data = advance_data_iter(data, num_calibrated_images/batch_size)
score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples=num_predicted_images)
data.reset()
print('Finished running FP32 model for inference')
print('\n')
#################################################################################################

#################################################################################################
print('====================================================================\n')
# cudnn int8 convolution only support channels a multiple of 4
# have to ignore quantizing conv0 node
ignore_symbols = []
ignore_sym_names = ['conv0']
for name in ignore_sym_names:
nodes = sym.get_internals()
idx = nodes.list_outputs().index(name + '_output')
ignore_symbols.append(nodes[idx])

print('Quantizing the FP32 model...')
qsym = quantize_graph(sym, ignore_symbols=ignore_symbols, offline_params=arg_params.keys())
print('Finished quantizing the FP32 model')
print('Quantizing parameters of the FP32 model...')
qarg_params = quantize_params(qsym, arg_params)
print('Finished quantizing the parameters of the FP32 model')

print('Running quantized model (INT8) for inference...')
data.reset()
# make sure that int8 uncalibrated inference works on the same images as calibrated quantized model
data = advance_data_iter(data, num_calibrated_images/batch_size)
score(qsym, qarg_params, aux_params, data, devs, label_name, max_num_examples=num_predicted_images)
data.reset()
print('Finished running quantized model (INT8) for inference')
print('\n')
#################################################################################################

#################################################################################################
print('====================================================================\n')
# calibrate model by collecting quantiles from quantized model outputs
print('Collecting quantiles from quantized model outputs...')
include_layer = lambda name: name.startswith('quantized_') and name.endswith('_output')
collector = LayerOutputQuantileCollector(low_quantile=low_quantile,
high_quantlie=high_quantile,
include_layer=include_layer)
mod = mx.mod.Module(symbol=qsym, context=devs, label_names=[label_name, ])
mod.bind(for_training=False,
data_shapes=data.provide_data,
label_shapes=data.provide_label)
mod.set_params(qarg_params, aux_params)
data.reset()
quantile_dict = mx.quantization.collect_layer_output_quantiles(mod, data, collector,
max_num_examples=num_calibrated_images)
data.reset()
data = advance_data_iter(data, num_calibrated_images/batch_size)
print('Finished collecting quantiles from quantized model outputs')
print('Calibrating quantized model using INT32 quantiles...')
calib_table_type = 'int32'
cqsym = mx.quantization.calibrate_quantized_sym(qsym, quantile_dict, calib_table_type)
print('Finished calibrating quantized model')
print('Running calibrated quantized model (INT32 calibration table) for inference...')
score(cqsym, qarg_params, aux_params, data, devs, label_name, max_num_examples=num_predicted_images)
data.reset()
print('Finished running calibrated quantized model (INT32 calibration table) for inference')
print('\n')
#################################################################################################

#################################################################################################
print('====================================================================\n')
# calibrate model by collecting quantiles from FP32 model outputs
print('Collecting quantiles from FP32 model outputs...')
include_layer = lambda name: name.endswith('_output')
collector = LayerOutputQuantileCollector(low_quantile=low_quantile,
high_quantlie=high_quantile,
include_layer=include_layer)
mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name,])
mod.bind(for_training=False,
data_shapes=data.provide_data,
label_shapes=data.provide_label)
mod.set_params(arg_params, aux_params)
data.reset()
quantile_dict = mx.quantization.collect_layer_output_quantiles(mod, data, collector,
max_num_examples=num_calibrated_images)
data.reset()
data = advance_data_iter(data, num_calibrated_images/batch_size)
print('Finished collecting quantiles from FP32 model outputs...')
print('Calibrating quantized model using FP32 quantiles...')
calib_table_type = 'float32'
cqsym = mx.quantization.calibrate_quantized_sym(qsym, quantile_dict, calib_table_type)
print('Finished calibrating quantized model using FP32 quantiles')
print('Running calibrated quantized model (FP32 calibration table) for inference...')
score(cqsym, qarg_params, aux_params, data, devs, label_name, max_num_examples=num_predicted_images)
print('Finished running calibrated quantized model (FP32 calibration table) for inference')
data.reset()
print('\n')
#################################################################################################
23 changes: 21 additions & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1194,12 +1194,31 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
const int **aux_type_data,
int *complete);

MXNET_DLL int MXQuantizeGraph(SymbolHandle sym,
SymbolHandle *ret_sym,
MXNET_DLL int MXQuantizeGraph(SymbolHandle sym_handle,
SymbolHandle *ret_sym_handle,
mx_uint num_ignore,
SymbolHandle *ignore_symbols,
mx_uint num_offline,
const char **offline_params);

/*!
* \brief Set calibration table to node attributes in the sym
* \param sym_handle symbol whose node attributes are to be set by calibration table
* \param calib_table_type calibration table data type, either "int32" or "float32"
* \param num_layers number of layers in the calibration table
* \param layer names stored as keys in the calibration table
* \param low_quantiles low quantiles of layers stored in the calibration table
* \param high_quantiles high quantiles of layers stored in the calibration table
* \param ret_sym_handle returned symbol
*/
MXNET_DLL int MXSetCalibTableToQuantizedGraph(SymbolHandle sym_handle,
const char* calib_table_type,
const mx_uint num_layers,
const char** layer_names,
const float* low_quantiles,
const float* high_quantiles,
SymbolHandle* ret_sym_handle);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,10 @@ def install_monitor(self, mon):
"""Installs monitor on all executors. """
assert self.binded
self._exec_group.install_monitor(mon)

def set_monitor_callback(self, cb):
"""Set monitor callback function to executor.
Only supports the module with only one executor."""
assert len(self._exec_group.execs) == 1, 'Module supports setting callback' \
'function for only one executor'
self._exec_group.execs[0].set_monitor_callback(cb)
Loading

0 comments on commit 8d82b2b

Please sign in to comment.