Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add int8 data loader #14123

Merged
merged 8 commits into from
Mar 4, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/perl/io.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Then we can call `$mod->fit($nd_iter, num_epoch=>2)` to train `loss` by 2 epochs
mx->io->NDArrayIter
mx->io->CSVIter
mx->io->ImageRecordIter
mx->io->ImageRecordInt8Iter
mx->io->ImageRecordUInt8Iter
mx->io->MNISTIter
mx->recordio->MXRecordIO
Expand Down
1 change: 1 addition & 0 deletions docs/api/python/io/io.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ A detailed tutorial is available at
io.CSVIter
io.LibSVMIter
io.ImageRecordIter
io.ImageRecordInt8Iter
io.ImageRecordUInt8Iter
io.MNISTIter
recordio.MXRecordIO
Expand Down
8 changes: 7 additions & 1 deletion example/quantization/imagenet_gen_qsym_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
help='If enabled, the quantize op will '
'be calibrated offline if calibration mode is '
'enabled')
# NOTE: argparse's `type=bool` calls bool() on the raw string, so every
# non-empty value (including 'False' and '0') would be parsed as True and the
# option could never be disabled from the command line. Parse the text
# explicitly instead: only the usual "false" spellings (or an empty string)
# turn the option off; any other value keeps it enabled, as before.
parser.add_argument('--use-quantized-data-layer',
type=lambda s: str(s).strip().lower() not in ('', '0', 'false', 'no', 'n', 'off'),
default=True,
help='If enabled, data layer will be already quantized.')
args = parser.parse_args()
ctx = mx.cpu(0)
logging.basicConfig()
Expand Down Expand Up @@ -273,6 +275,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
ctx=ctx, excluded_sym_names=excluded_sym_names,
calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
use_quantized_data_layer=args.use_quantized_data_layer,
logger=logger)
sym_name = '%s-symbol.json' % (prefix + '-quantized')
else:
Expand All @@ -295,7 +298,10 @@ def save_params(fname, arg_params, aux_params, logger=None):
calib_mode=calib_mode, calib_data=data,
num_calib_examples=num_calib_batches * batch_size,
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
label_names=(label_name,), logger=logger)
label_names=(label_name,),
use_quantized_data_layer=args.use_quantized_data_layer,
logger=logger)

if calib_mode == 'entropy':
suffix = '-quantized-%dbatches-entropy' % num_calib_batches
elif calib_mode == 'naive':
Expand Down
84 changes: 65 additions & 19 deletions example/quantization/imagenet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import os
import time
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet.contrib.quantization import *
Expand Down Expand Up @@ -98,22 +99,36 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
logger.info(m.get())


def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger=None):
# get mod
cur_path = os.path.dirname(os.path.realpath(__file__))
symbol_file_path = os.path.join(cur_path, symbol_file)
if logger is not None:
logger.info('Loading symbol from file %s' % symbol_file_path)
sym = mx.sym.load(symbol_file_path)
mod = mx.mod.Module(symbol=sym, context=ctx)
mod.bind(for_training = False,
inputs_need_grad = False,
data_shapes = [('data', (batch_size,)+data_shape)])
if data_layer_type == "float32":
dshape = mx.io.DataDesc(name='data', shape=(
batch_size,) + data_shape, dtype=np.float32)
elif data_layer_type == 'uint8':
dshape = mx.io.DataDesc(name='data', shape=(
batch_size,) + data_shape, dtype=np.uint8)
else: # int8
dshape = mx.io.DataDesc(name='data', shape=(
batch_size,) + data_shape, dtype=np.int8)
mod.bind(for_training=False,
inputs_need_grad=False,
data_shapes=[dshape])
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

# get data
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, []) # empty label
if data_layer_type == "float32":
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx, dtype=data_layer_type)
for _, shape in mod.data_shapes]
else:
data = [mx.nd.full(shape=shape, val=127, ctx=ctx, dtype=data_layer_type)
for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, []) # empty label

# run
dry_run = 5 # use 5 iterations to warm up
Expand Down Expand Up @@ -152,6 +167,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
help='shuffling seed, see'
' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
' for more details')
parser.add_argument('--data-layer-type', type=str, default="float32",
choices=['float32', 'int8', 'uint8'],
help='data type for data layer')

args = parser.parse_args()

Expand Down Expand Up @@ -192,24 +210,52 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
data_shape = tuple([int(i) for i in image_shape.split(',')])
logger.info('Input data shape = %s' % str(data_shape))

data_layer_type = args.data_layer_type
if args.benchmark == False:
dataset = args.dataset
download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
logger.info('Dataset for inference: %s' % dataset)

# creating data iterator
data = mx.io.ImageRecordIter(path_imgrec=dataset,
label_width=1,
preprocess_threads=data_nthreads,
batch_size=batch_size,
data_shape=data_shape,
label_name=label_name,
rand_crop=False,
rand_mirror=False,
shuffle=True,
shuffle_chunk_seed=3982304,
seed=48564309,
**combine_mean_std)
if data_layer_type == 'float32':
data = mx.io.ImageRecordIter(path_imgrec=dataset,
label_width=1,
preprocess_threads=data_nthreads,
batch_size=batch_size,
data_shape=data_shape,
label_name=label_name,
rand_crop=False,
rand_mirror=False,
shuffle=args.shuffle_dataset,
shuffle_chunk_seed=args.shuffle_chunk_seed,
seed=args.shuffle_seed,
**combine_mean_std)
elif data_layer_type == 'uint8':
data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset,
label_width=1,
preprocess_threads=data_nthreads,
batch_size=batch_size,
data_shape=data_shape,
label_name=label_name,
rand_crop=False,
rand_mirror=False,
shuffle=args.shuffle_dataset,
shuffle_chunk_seed=args.shuffle_chunk_seed,
seed=args.shuffle_seed,
**combine_mean_std)
else: #int8
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset,
label_width=1,
preprocess_threads=data_nthreads,
batch_size=batch_size,
data_shape=data_shape,
label_name=label_name,
rand_crop=False,
rand_mirror=False,
shuffle=args.shuffle_dataset,
shuffle_chunk_seed=args.shuffle_chunk_seed,
seed=args.shuffle_seed,
**combine_mean_std)

# loading model
sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
Expand All @@ -224,5 +270,5 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
max_num_examples=num_inference_images, logger=logger)
else:
logger.info('Running model %s for inference' % symbol_file)
speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, logger)
speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, data_layer_type, logger)
logger.info('batch size %2d, image/sec: %f', batch_size, speed)
8 changes: 5 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1571,14 +1571,16 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
* \param excluded_symbols op names to be excluded from being quantized
* \param num_offline number of parameters that are quantized offline
* \param offline_params array of c strings representing the names of params quantized offline
* \param quantized_dtype the quantized destination type for input data.
* \param calib_quantize **Deprecated**. quantize op will always be calibrated if could.
* \param quantized_dtype the quantized destination type for input data
* \param calib_quantize **Deprecated**. quantize op will always be calibrated if possible
* \param use_quantized_data_layer if true, use quantized data layer
*/
MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
const mx_uint num_excluded_symbols,
const char **excluded_symbols,
const mx_uint num_offline, const char **offline_params,
const char *quantized_dtype, const bool calib_quantize);
const char *quantized_dtype, const bool calib_quantize,
const bool use_quantized_data_layer);

/*!
* \brief Set calibration table to node attributes in the sym
Expand Down
1 change: 1 addition & 0 deletions perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ extends 'AI::MXNet::DataIter';
mx->io->CSVIter Returns the CSV file iterator.
mx->io->LibSVMIter Returns the LibSVM iterator which returns data with csr storage type.
mx->io->ImageRecordIter Iterates on image RecordIO files
mx->io->ImageRecordInt8Iter Iterating on image RecordIO files
mx->io->ImageRecordUInt8Iter Iterating on image RecordIO files
mx->io->MNISTIter Iterating on the MNIST dataset.
mx->recordio->MXRecordIO Reads/writes RecordIO data format, supporting sequential read and write.
Expand Down
15 changes: 11 additions & 4 deletions python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def _quantize_params(qsym, params, th_dict):
quantized_params[name] = ndarray.array([th_dict[output][1]])
return quantized_params

def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8'):
def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8',
use_quantized_data_layer=False):
"""Given a symbol object representing a neural network of data type FP32,
quantize it into a INT8 network.

Expand All @@ -97,6 +98,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_
avoided.
quantized_dtype: str
The quantized destination type for input data.
use_quantized_data_layer: bool
If true, use quantized data layer.
"""
num_excluded_symbols = 0
if excluded_symbols is not None:
Expand All @@ -120,7 +123,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_
mx_uint(num_offline),
c_array(ctypes.c_char_p, offline),
c_str(quantized_dtype),
ctypes.c_bool(True)))
ctypes.c_bool(True),
ctypes.c_bool(use_quantized_data_layer)))
return Symbol(out)


Expand Down Expand Up @@ -419,7 +423,7 @@ def quantize_model(sym, arg_params, aux_params,
data_names=('data',), label_names=('softmax_label',),
ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
calib_data=None, num_calib_examples=None, calib_layer=None,
quantized_dtype='int8', logger=logging):
quantized_dtype='int8', use_quantized_data_layer=False, logger=logging):
"""User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
The backend quantized operators are only enabled for Linux systems. Please do not run
inference using the quantized models on Windows for now.
Expand Down Expand Up @@ -473,6 +477,8 @@ def quantize_model(sym, arg_params, aux_params,
The quantized destination type for input data. Currently support 'int8'
, 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result.
Default value is 'int8'.
use_quantized_data_layer : bool
If true, use quantized data layer.
logger : Object
A logging object for printing information during the process of quantization.

Expand All @@ -495,7 +501,8 @@ def quantize_model(sym, arg_params, aux_params,
' expected `int8`, `uint8` or `auto`' % quantized_dtype)
qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names,
offline_params=list(arg_params.keys()),
quantized_dtype=quantized_dtype)
quantized_dtype=quantized_dtype,
use_quantized_data_layer=use_quantized_data_layer)

th_dict = {}
if calib_mode is not None and calib_mode != 'none':
Expand Down
4 changes: 3 additions & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,8 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
const mx_uint num_offline,
const char **offline_params,
const char *quantized_dtype,
const bool calib_quantize) {
const bool calib_quantize,
const bool use_quantized_data_layer) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
Expand All @@ -668,6 +669,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
g.attrs["excluded_nodes"] = std::make_shared<nnvm::any>(std::move(excluded_node_names));
g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
g.attrs["quantized_dtype"] = std::make_shared<nnvm::any>(std::move(quantized_type));
g.attrs["use_quantized_data_layer"] = std::make_shared<nnvm::any>(use_quantized_data_layer);
g = ApplyPass(std::move(g), "QuantizeGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
Expand Down
55 changes: 45 additions & 10 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
float RGBA_MULT[4] = { 0 };
float RGBA_BIAS[4] = { 0 };
float RGBA_MEAN[4] = { 0 };
int16_t RGBA_MEAN_INT[4] = {0};
mshadow::Tensor<cpu, 3, DType>& data = (*data_ptr);
if (!std::is_same<DType, uint8_t>::value) {
RGBA_MULT[0] = contrast_scaled / normalize_param_.std_r;
Expand All @@ -387,6 +388,10 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
RGBA_MEAN[1] = normalize_param_.mean_g;
RGBA_MEAN[2] = normalize_param_.mean_b;
RGBA_MEAN[3] = normalize_param_.mean_a;
RGBA_MEAN_INT[0] = std::round(normalize_param_.mean_r);
RGBA_MEAN_INT[1] = std::round(normalize_param_.mean_g);
RGBA_MEAN_INT[2] = std::round(normalize_param_.mean_b);
RGBA_MEAN_INT[3] = std::round(normalize_param_.mean_a);
}
}

Expand All @@ -408,17 +413,30 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
for (int i = 0; i < res.rows; ++i) {
const uchar* im_data = res.ptr<uchar>(i);
for (int j = 0; j < res.cols; ++j) {
for (int k = 0; k < n_channels; ++k) {
RGBA[k] = im_data[swap_indices[k]];
}
if (!std::is_same<DType, uint8_t>::value) {
// normalize/mirror here to avoid memory copies
// logic from iter_normalize.h, function SetOutImg
if (std::is_same<DType, int8_t>::value) {
if (meanfile_ready_) {
for (int k = 0; k < n_channels; ++k) {
RGBA[k] = cv::saturate_cast<int8_t>(im_data[swap_indices[k]] -
static_cast<int16_t>(std::round(meanimg_[k][i][j])));
}
} else {
for (int k = 0; k < n_channels; ++k) {
RGBA[k] = cv::saturate_cast<int8_t>(im_data[swap_indices[k]] - RGBA_MEAN_INT[k]);
}
}
} else {
for (int k = 0; k < n_channels; ++k) {
if (meanfile_ready_) {
RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k];
} else {
RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k];
RGBA[k] = im_data[swap_indices[k]];
}
if (!std::is_same<DType, uint8_t>::value) {
// normalize/mirror here to avoid memory copies
// logic from iter_normalize.h, function SetOutImg
for (int k = 0; k < n_channels; ++k) {
if (meanfile_ready_) {
RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k];
} else {
RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k];
}
}
}
}
Expand Down Expand Up @@ -795,5 +813,22 @@ the data type instead of ``float``.
.set_body([]() {
return new ImageRecordIter2<uint8_t>();
});

// Registers the ``ImageRecordInt8Iter`` data iterator.
// The factory lambda given to set_body() creates an ImageRecordIter2
// instantiated with int8_t as its data type. The add_arguments() calls
// attach the fields of the parser, record, batch and prefetcher parameter
// structs, followed by the default augmenter parameters.
MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter)
.describe(R"code(Iterating on image RecordIO files

This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as
the data type instead of ``float``.

)code" ADD_FILELINE)
.add_arguments(ImageRecParserParam::__FIELDS__())
.add_arguments(ImageRecordParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.add_arguments(ListDefaultAugParams())
.set_body([]() {
return new ImageRecordIter2<int8_t>();
});

} // namespace io
} // namespace mxnet
Loading