Skip to content

Commit

Permalink
Add Calibtool for GluonCV (#893)
Browse files Browse the repository at this point in the history
* add calibration tool

* add load static model for inference

* remove useless code

* support ssd

* add benchmark result

* add wildcard match for exclude layers

* add quantized concat ssd since it has been fixed in master

* exclude concat

* support fc int8

* add c5.24xlarge 1S perf

* enable dataiter api

* improve script

* improve script and add ut

* exclude concat in ssd to solve accuracy issue

* upgrade mxnet-mkl to 0807

* upgrade to 0808

* add doc

* add eol

* fix doc

* fix typo
  • Loading branch information
xinyu-intel authored Aug 12, 2019
1 parent 3e85b82 commit 059e6de
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 105 deletions.
2 changes: 1 addition & 1 deletion docs/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- pip:
- https://github.com/mli/mx-theme/tarball/0.3.1
- sphinx-gallery
- mxnet-cu92mkl==1.6.0b20190802
- mxnet-cu92mkl==1.6.0b20190808
# - guzzle_sphinx_theme
- recommonmark
- Image
Expand Down
135 changes: 121 additions & 14 deletions docs/tutorials/deployment/int8_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
This is a tutorial which illustrates how to use quantized GluonCV
models for inference on Intel Xeon Processors to gain higher performance.
The following example requires ``GluonCV>=0.4`` and ``MXNet-mkl>=1.5.0b20190623``. Please follow `our installation guide <../../index.html#installation>`__ to install or upgrade GluonCV and nightly build of MXNet if necessary.
The following example requires ``GluonCV>=0.4`` and ``MXNet-mkl>=1.5.0b20190807``. Please follow `our installation guide <../../index.html#installation>`__ to install or upgrade GluonCV and nightly build of MXNet if necessary.
Introduction
------------
Expand Down Expand Up @@ -60,36 +60,143 @@
export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'`
export OMP_NUM_THREADS=$(CPUs)
# with Pascal VOC validation dataset saved on disk
python eval_ssd.py --network=vgg16_atrous --quantized --data-shape=300 --batch-size=224 --dataset=voc --benchmark
python eval_ssd.py --network=mobilenet1.0 --quantized --data-shape=512 --batch-size=224 --dataset=voc --benchmark
Usage:
::
SYNOPSIS
python eval_ssd.py [-h] [--network NETWORK] [--quantized]
python eval_ssd.py [-h] [--network NETWORK] [--deploy]
[--model-prefix] [--quantized]
[--data-shape DATA_SHAPE] [--batch-size BATCH_SIZE]
[--benchmark BENCHMARK] [--num-iterations NUM_ITERATIONS]
[--dataset DATASET] [--num-workers NUM_WORKERS]
[--num-gpus NUM_GPUS] [--pretrained PRETRAINED]
[--save-prefix SAVE_PREFIX]
[--save-prefix SAVE_PREFIX] [--calibration CALIBRATION]
[--num-calib-batches NUM_CALIB_BATCHES]
[--quantized-dtype {auto,int8,uint8}]
[--calib-mode CALIB_MODE]
OPTIONS
-h, --help show this help message and exit
--network NETWORK Base network name
--quantized use int8 pretrained model
-h, --help show this help message and exit
--network NETWORK base network name
--deploy whether load static model for deployment
--model-prefix MODEL_PREFIX
load static model as hybridblock.
--quantized use int8 pretrained model
--data-shape DATA_SHAPE
Input data shape
input data shape
--batch-size BATCH_SIZE
eval mini-batch size
--benchmark BENCHMARK run dummy-data based benchmarking
--benchmark BENCHMARK run dummy-data based benchmarking
--num-iterations NUM_ITERATIONS number of benchmarking iterations.
--dataset DATASET eval dataset.
--dataset DATASET eval dataset.
--num-workers NUM_WORKERS, -j NUM_WORKERS
Number of data workers
--num-gpus NUM_GPUS number of gpus to use.
number of data workers
--num-gpus NUM_GPUS number of gpus to use.
--pretrained PRETRAINED
Load weights from previously saved parameters.
load weights from previously saved parameters.
--save-prefix SAVE_PREFIX
Saving parameter prefix
saving parameter prefix
--calibration quantize model
--num-calib-batches NUM_CALIB_BATCHES
number of batches for calibration
--quantized-dtype {auto,int8,uint8}
quantization destination data type for input data
--calib-mode CALIB_MODE
calibration mode used for generating calibration table
for the quantized symbol; supports 1. none: no
calibration will be used. The thresholds for
quantization will be calculated on the fly. This will
result in inference speed slowdown and loss of
accuracy in general. 2. naive: simply take min and max
values of layer outputs as thresholds for
quantization. In general, the inference accuracy
worsens with more examples used in calibration. It is
recommended to use `entropy` mode as it produces more
accurate inference results. 3. entropy: calculate KL
divergence of the fp32 output and quantized output for
optimal thresholds. This mode is expected to produce
the best inference accuracy of all three kinds of
quantized models if the calibration dataset is
representative enough of the inference dataset.
Calibration Tool
----------------
GluonCV also delivered calibration tool for users to quantize their models into int8 with their own dataset. Currently, calibration tool only supports hybridized gluon models. Below is an example of quantizing SSD model.
.. code:: bash
# Calibration
python eval_ssd.py --network=mobilenet1.0 --data-shape=512 --batch-size=224 --dataset=voc --calibration --num-calib-batches=5 --calib-mode=naive
# INT8 Inference
python eval_ssd.py --network=mobilenet1.0 --data-shape=512 --batch-size=224 --deploy --model-prefix=./model/ssd_512_mobilenet1.0_voc-quantized-naive
The first command will launch naive calibration to quantize your ssd_mobilenet1.0 model to int8 by using a subset (5 batches) of your given dataset. Users can tune the int8 accuracy by setting different calibration configurations. After calibration, quantized model and parameter will be saved on your disk. Then, the second command will load quantized model as a symbolblock for inference.
Users can also quantize their own gluon hybridized model by using `quantize_net` api. Below are some descriptions.
API:
::
CODE
from mxnet.contrib.quantization import *
quantized_net = quantize_net(network, quantized_dtype='auto',
exclude_layers=None, exclude_layers_match=None,
calib_data=None, data_shapes=None,
calib_mode='naive', num_calib_examples=None,
ctx=mx.cpu(), logger=logging)
Parameters
network : Gluon HybridBlock
Defines the structure of a neural network for FP32 data types.
quantized_dtype : str
The quantized destination type for input data. Currently support 'int8'
, 'uint8' and 'auto'.
'auto' means automatically select output type according to calibration result.
Default value is 'int8'.
exclude_layers : list of strings
A list of strings representing the names of the symbols that users want to excluding
exclude_layers_match : list of strings
A list of strings wildcard matching the names of the symbols that users want to excluding
from being quantized.
calib_data : mx.io.DataIter or gluon.DataLoader
A iterable data loading object.
data_shapes : list
List of DataDesc, required if calib_data is not provided
calib_mode : str
If calib_mode='none', no calibration will be used and the thresholds for
requantization after the corresponding layers will be calculated at runtime by
calling min and max operators. The quantized models generated in this
mode are normally 10-20% slower than those with calibrations during inference.
If calib_mode='naive', the min and max values of the layer outputs from a calibration
dataset will be directly taken as the thresholds for quantization.
If calib_mode='entropy', the thresholds for quantization will be
derived such that the KL divergence between the distributions of FP32 layer outputs and
quantized layer outputs is minimized based upon the calibration dataset.
calib_layer : function
Given a layer's output name in string, return True or False for deciding whether to
calibrate this layer. If yes, the statistics of the layer's output will be collected;
otherwise, no information of the layer's output will be collected. If not provided,
all the layers' outputs that need requantization will be collected.
num_calib_examples : int or None
The maximum number of examples that user would like to use for calibration.
If not provided, the whole calibration dataset will be used.
ctx : Context
Defines the device that users want to run forward propagation on the calibration
dataset for collecting layer output statistics. Currently, only supports single context.
Currently only support CPU with MKL-DNN backend.
logger : Object
A logging object for printing information during the process of quantization.
Returns
network : Gluon SymbolBlock
Defines the structure of a neural network for INT8 data types.
"""
35 changes: 35 additions & 0 deletions scripts/classification/imagenet/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,39 @@
# Image Classification on ImageNet

## Inference/Calibration Tutorial

### Float32 Inference

```
python verify_pretrained.py --model=resnet50_v1d_0.11 --batch-size=1
```

### Calibration

Naive calibrate model by using 5 batch data (32 images per batch). Quantized model will be saved into `./model/`.

```
python verify_pretrained.py --model=resnet50_v1d_0.11 --batch-size=32 --calibration
```

### INT8 Inference

```
python verify_pretrained.py --model=resnet50_v1d_0.11 --batch-size=1 --deploy --model-prefix=./model/resnet50_v1d_0.11-quantized-naive
```

## Performance

model | f32 latency(ms) | s8 latency(ms) | f32 throughput(fps, BS=64) | s8 throughput(fps, BS=64) | f32 accuracy | s8 accuracy
-- | -- | -- | -- | -- | -- | --
resnet50_v1 | 11.36 | 2.54 | 190.2 | 1363.75 | 77.21/93.56 | 76.34/93.13
resnet50_v1d_0.11 | 8.84 | 1.74 | 1070.66 | 10686.77 | 63.06/84.64 | 62.68/84.43
mobilenet1.0 | 3.88 | 0.88 | 583.05 | 5615.58 | 73.28/91.22 | 72.23/90.64
mobilenetv2_1.0 | 18.10 | 1.34 | 226.27 | 5005.94 | 71.89/90.53 | 70.87/89.88
squeezenet1.0 | 4.18 | 0.96 | 590.76 | 3393.09 | 57.74/80.33 | 56.98/79.66
squeezenet1.1 | 3.31 | 0.87 | 964.83 | 6027.15 | 58.00/80.47 | 57.02/79.73
inceptionv3 | 20.73 | 4.99 | 156.63 | 917.67 | 78.80/94.37 | 77.36/93.57
vgg16 | 16.71 | 7.63 | 87.17 | 399.62 | 73.06/91.18 | 71.94/90.59

Please refer to [GluonCV Model Zoo](http://gluon-cv.mxnet.io/model_zoo/index.html#image-classification)
for available pretrained models, training hyper-parameters, etc.
Loading

0 comments on commit 059e6de

Please sign in to comment.