diff --git a/docs/api/perl/io.md b/docs/api/perl/io.md index be4976425553..ca3b0f1e09f7 100644 --- a/docs/api/perl/io.md +++ b/docs/api/perl/io.md @@ -69,6 +69,7 @@ Then we can call `$mod->fit($nd_iter, num_epoch=>2)` to train `loss` by 2 epochs mx->io->NDArrayIter mx->io->CSVIter mx->io->ImageRecordIter +mx->io->ImageRecordInt8Iter mx->io->ImageRecordUInt8Iter mx->io->MNISTIter mx->recordio->MXRecordIO diff --git a/docs/api/python/io/io.md b/docs/api/python/io/io.md index c0dc8d1efb1d..13a612196eee 100644 --- a/docs/api/python/io/io.md +++ b/docs/api/python/io/io.md @@ -75,6 +75,7 @@ A detailed tutorial is available at io.CSVIter io.LibSVMIter io.ImageRecordIter + io.ImageRecordInt8Iter io.ImageRecordUInt8Iter io.MNISTIter recordio.MXRecordIO diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index 0725165b0ca5..47e206303e99 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -19,6 +19,7 @@ import logging import os import time +import numpy as np import mxnet as mx from mxnet import nd from mxnet.contrib.quantization import * @@ -98,7 +99,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger.info(m.get()) -def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): +def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger=None): # get mod cur_path = os.path.dirname(os.path.realpath(__file__)) symbol_file_path = os.path.join(cur_path, symbol_file) @@ -106,14 +107,28 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): logger.info('Loading symbol from file %s' % symbol_file_path) sym = mx.sym.load(symbol_file_path) mod = mx.mod.Module(symbol=sym, context=ctx) - mod.bind(for_training = False, - inputs_need_grad = False, - data_shapes = [('data', (batch_size,)+data_shape)]) + if data_layer_type == "int8": + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.int8) + elif data_layer_type == 'uint8': + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.uint8) + else: # float32 + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.float32) + mod.bind(for_training=False, + inputs_need_grad=False, + data_shapes=[dshape]) mod.init_params(initializer=mx.init.Xavier(magnitude=2.)) # get data - data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes] - batch = mx.io.DataBatch(data, []) # empty label + if data_layer_type == "float32": + data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx, dtype=data_layer_type) + for _, shape in mod.data_shapes] + else: + data = [mx.nd.full(shape=shape, val=127, ctx=ctx, dtype=data_layer_type) + for _, shape in mod.data_shapes] + batch = mx.io.DataBatch(data, []) # empty label # run dry_run = 5 # use 5 iterations to warm up @@ -152,6 +167,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): help='shuffling seed, see' ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' ' for more details') + parser.add_argument('--data-layer-type', type=str, default="float32", + choices=['float32', 'int8', 'uint8'], + help='data type for data layer') args = parser.parse_args() @@ -192,24 +210,52 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): data_shape = tuple([int(i) for i in image_shape.split(',')]) logger.info('Input data shape = %s' % str(data_shape)) + data_layer_type = args.data_layer_type if args.benchmark == False: dataset = args.dataset download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset) logger.info('Dataset for inference: %s' % dataset) # creating data iterator - data = mx.io.ImageRecordIter(path_imgrec=dataset, - label_width=1, - preprocess_threads=data_nthreads, - batch_size=batch_size, - data_shape=data_shape, - label_name=label_name, - rand_crop=False, - rand_mirror=False, - shuffle=True, - shuffle_chunk_seed=3982304, - seed=48564309, - **combine_mean_std) + if data_layer_type == 'int8': + data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **combine_mean_std) + elif data_layer_type == 'uint8': + data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **combine_mean_std) + else: #float32 + data = mx.io.ImageRecordIter(path_imgrec=dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **combine_mean_std) # loading model sym, arg_params, aux_params = load_model(symbol_file, param_file, logger) @@ -224,5 +270,5 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): max_num_examples=num_inference_images, logger=logger) else: logger.info('Running model %s for inference' % symbol_file) - speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, logger) + speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, data_layer_type, logger) logger.info('batch size %2d, image/sec: %f', batch_size, speed) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 76a4995d15c0..9a24b7516128 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1602,8 +1602,8 @@ MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym, * \param excluded_symbols op names to be excluded from being quantized * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline - * \param quantized_dtype the quantized destination type for input data. - * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could. + * \param quantized_dtype the quantized destination type for input data + * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm index 297ceb8c0b24..19e7cfdb8fe3 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm @@ -642,6 +642,7 @@ extends 'AI::MXNet::DataIter'; mx->io->CSVIter Returns the CSV file iterator. mx->io->LibSVMIter Returns the LibSVM iterator which returns data with csr storage type. mx->io->ImageRecordIter Iterates on image RecordIO files + mx->io->ImageRecordInt8Iter Iterating on image RecordIO files mx->io->ImageRecordUInt8Iter Iterating on image RecordIO files mx->io->MNISTIter Iterating on the MNIST dataset. mx->recordio->MXRecordIO Reads/writes RecordIO data format, supporting sequential read and write. diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index 5d5261b22611..0834dd7786ee 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -372,6 +372,7 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, float RGBA_MULT[4] = { 0 }; float RGBA_BIAS[4] = { 0 }; float RGBA_MEAN[4] = { 0 }; + int16_t RGBA_MEAN_INT[4] = {0}; mshadow::Tensor& data = (*data_ptr); if (!std::is_same::value) { RGBA_MULT[0] = contrast_scaled / normalize_param_.std_r; @@ -387,6 +388,10 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, RGBA_MEAN[1] = normalize_param_.mean_g; RGBA_MEAN[2] = normalize_param_.mean_b; RGBA_MEAN[3] = normalize_param_.mean_a; + RGBA_MEAN_INT[0] = std::round(normalize_param_.mean_r); + RGBA_MEAN_INT[1] = std::round(normalize_param_.mean_g); + RGBA_MEAN_INT[2] = std::round(normalize_param_.mean_b); + RGBA_MEAN_INT[3] = std::round(normalize_param_.mean_a); } } @@ -408,17 +413,30 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, for (int i = 0; i < res.rows; ++i) { const uchar* im_data = res.ptr(i); for (int j = 0; j < res.cols; ++j) { - for (int k = 0; k < n_channels; ++k) { - RGBA[k] = im_data[swap_indices[k]]; - } - if (!std::is_same::value) { - // normalize/mirror here to avoid memory copies - // logic from iter_normalize.h, function SetOutImg + if (std::is_same::value) { + if (meanfile_ready_) { + for (int k = 0; k < n_channels; ++k) { + RGBA[k] = cv::saturate_cast(im_data[swap_indices[k]] - + static_cast(std::round(meanimg_[k][i][j]))); + } + } else { + for (int k = 0; k < n_channels; ++k) { + RGBA[k] = cv::saturate_cast(im_data[swap_indices[k]] - RGBA_MEAN_INT[k]); + } + } + } else { for (int k = 0; k < n_channels; ++k) { - if (meanfile_ready_) { - RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k]; - } else { - RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k]; + RGBA[k] = im_data[swap_indices[k]]; + } + if (!std::is_same::value) { + // normalize/mirror here to avoid memory copies + // logic from iter_normalize.h, function SetOutImg + for (int k = 0; k < n_channels; ++k) { + if (meanfile_ready_) { + RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k]; + } else { + RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k]; + } } } } @@ -795,5 +813,22 @@ the data type instead of ``float``. .set_body([]() { return new ImageRecordIter2(); }); + +MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter) +.describe(R"code(Iterating on image RecordIO files + +This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as +the data type instead of ``float``. + +)code" ADD_FILELINE) +.add_arguments(ImageRecParserParam::__FIELDS__()) +.add_arguments(ImageRecordParam::__FIELDS__()) +.add_arguments(BatchParam::__FIELDS__()) +.add_arguments(PrefetcherParam::__FIELDS__()) +.add_arguments(ListDefaultAugParams()) +.set_body([]() { + return new ImageRecordIter2(); + }); + } // namespace io } // namespace mxnet diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index e201d290e8c6..d6060e54a82c 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -123,13 +123,32 @@ static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContex const std::vector& req, const std::vector& outputs) { const QuantizeV2Param& param = nnvm::get(attrs.parsed); - auto out_type = GetOutputType(param); - if (out_type == mshadow::kUint8) { - MKLDNNQuantizeComputeKer(inputs, outputs, param, req); - } else if (out_type == mshadow::kInt8) { - MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + if (inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8) { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + *outputs[1].data().dptr() = param.min_calib_range.value(); + *outputs[2].data().dptr() = param.max_calib_range.value(); + } else { + if (inputs[0].dtype() == mshadow::kUint8) { + *outputs[1].data().dptr() = 0; + *outputs[2].data().dptr() = 255; + } else { + *outputs[1].data().dptr() = -127; + *outputs[2].data().dptr() = 127; + } + } + if (req[0] != kWriteInplace) { + const_cast(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData()); + MKLDNNStream::Get()->Submit(); + } } else { - LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; + auto out_type = GetOutputType(param); + if (out_type == mshadow::kUint8) { + MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + } else if (out_type == mshadow::kInt8) { + MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + } else { + LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; + } } } diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 7a0998383824..e3c411931eba 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -137,51 +137,67 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, Stream *s = ctx.get_stream(); const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); auto out_type = GetOutputType(param); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), - param.max_calib_range.value(), MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), - param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); + + if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + *outputs[1].dptr() = param.min_calib_range.value(); + *outputs[2].dptr() = param.max_calib_range.value(); } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + if (inputs[0].type_flag_ == mshadow::kUint8) { + *outputs[1].dptr() = 0; + *outputs[2].dptr() = 255; + } else { + *outputs[1].dptr() = -127; + *outputs[2].dptr() = 127; + } } - } else { // model is not calibrated - mxnet::TShape src_shape, dst_shape; - const size_t actual_float_size = sizeof(float); - const size_t temp_reduce_size = - ConfigReduce(s, inputs[0].shape_, mxnet::TShape({1}), - &src_shape, &dst_shape); - Tensor temp_space = ctx.requested[0].get_space_typed( - Shape1(2 * actual_float_size + temp_reduce_size), s); - const int dev_id = ctx.run_ctx.ctx.dev_id; - TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, - dev_id); - TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, - dev_id); - Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, - Shape1(temp_reduce_size), s); - broadcast::Reduce( - s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - broadcast::Reduce( - s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), - in_max_t.dptr(), MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), - in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + UnaryOp::IdentityCompute(attrs, ctx, {inputs[0]}, req, outputs); + } else { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } + } else { // model is not calibrated + mxnet::TShape src_shape, dst_shape; + const size_t actual_float_size = sizeof(float); + const size_t temp_reduce_size = ConfigReduce( + s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * actual_float_size + temp_reduce_size), s); + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, + dev_id); + TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, + dev_id); + Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, + Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + broadcast::Reduce( + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } } } } @@ -202,7 +218,8 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 3U); const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); - TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32); + CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 || + in_attrs->at(0) == mshadow::kInt8); auto out_type = GetOutputType(param); if (out_type == mshadow::kUint8) { TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index e221d580d228..300cdfe3b751 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -88,6 +88,12 @@ If min_calib_range isn't presented, the output type will be int8. .set_attr("FComputeEx", MKLDNNQuantizeV2Compute) #endif .set_attr("FCompute", QuantizeV2Compute) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; +}) +.set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ + return std::vector{true}; +}) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py index 2e3ff06d2e18..39bfbcdeeafe 100644 --- a/tests/python/train/test_dtype.py +++ b/tests/python/train/test_dtype.py @@ -65,6 +65,30 @@ def get_iterator_uint8(kv): return (train, val) +def get_iterator_int8(kv): + data_shape = (3, 28, 28) + + train = mx.io.ImageRecordInt8Iter( + path_imgrec = "data/cifar/train.rec", + data_shape = data_shape, + batch_size = batch_size, + rand_crop = True, + rand_mirror = True, + num_parts = kv.num_workers, + part_index = kv.rank) + train = mx.io.PrefetchingIter(train) + + val = mx.io.ImageRecordInt8Iter( + path_imgrec = "data/cifar/test.rec", + rand_crop = False, + rand_mirror = False, + data_shape = data_shape, + batch_size = batch_size, + num_parts = kv.num_workers, + part_index = kv.rank) + + return (train, val) + def get_iterator_float32(kv): data_shape = (3, 28, 28) @@ -190,5 +214,10 @@ def test_cifar10(): run_cifar10(train, val, use_module=False) run_cifar10(train, val, use_module=True) + # test int8 input + (train, val) = get_iterator_int8(kv) + run_cifar10(train, val, use_module=False) + run_cifar10(train, val, use_module=True) + if __name__ == "__main__": test_cifar10()