From b6972bb055fc44481b072db3abb90e26ee27c787 Mon Sep 17 00:00:00 2001
From: YixinBao
Date: Thu, 8 Aug 2019 18:20:19 +0800
Subject: [PATCH] add int8 bn mkldnn implementation and test (#15664)

* add int8 bn mkldnn implementation and test
* fix lint
* fix ci
* enable int8 bn test only in mkldnn backend
* disable int8 bn forward test with gpu backend
* update int8 bn with reference to comments
* fix lint
* disable int8 bn gluon forward test with gpu backend
* disable uint8 bn forward test with mkldnn backend
* restore support mkldnn bn condition
* rm duplicate code
---
 cpp-package/scripts/OpWrapperGenerator.py     |   1 +
 src/operator/nn/batch_norm-inl.h              |  52 +++++--
 .../nn/mkldnn/mkldnn_batch_norm-inl.h         |  29 ++--
 .../mkldnn/mkldnn_quantized_batch_norm.cc     | 123 ++++++++++++++++
 .../quantization/quantize_graph_pass.cc       |  42 +++++-
 .../quantization/quantized_batch_norm.cc      | 137 ++++++++++++++++++
 .../python/quantization/test_quantization.py  |  90 +++++++++++-
 7 files changed, 446 insertions(+), 28 deletions(-)
 create mode 100644 src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
 create mode 100644 src/operator/quantization/quantized_batch_norm.cc

diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py
index 65ba247c25c8..853c519a73d4 100644
--- a/cpp-package/scripts/OpWrapperGenerator.py
+++ b/cpp-package/scripts/OpWrapperGenerator.py
@@ -93,6 +93,7 @@ class Arg:
         'int (non-negative)': 'uint32_t',\
         'long (non-negative)': 'uint64_t',\
         'int or None':'dmlc::optional<int>',\
+        'float or None':'dmlc::optional<float>',\
         'long':'int64_t',\
         'double':'double',\
         'double or None':'dmlc::optional<double>',\
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index a82771cd4401..17a16db5adcd 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -53,11 +53,19 @@ enum BatchNormOpOutputs {kOut, kMean, kVar};  // req, out_data
 enum BatchNormOpResource {kTempSpace};
 enum BatchNormOpAuxiliary {kMovingMean, kMovingVar};  // aux_states
 
-/*! \brief Default channel axis if none specified int he params */
+/*! \brief Default channel axis if none specified in the params */
 constexpr int DEFAULT_AXIS = 1;
 }  // namespace batchnorm
 
 /*! \brief Parameters for BatchNorm operator */
+namespace quantized_batchnorm {
+enum QuantizedBatchNormOpInputs {kData, kGamma, kBeta, kInMovingMean,
+                                 kInMovingVar, kDataMin, kDataMax};
+enum QuantizedBatchNormOutputs {kOut, kOutMin, kOutMax};
+enum QuantizedBatchNormOpAuxiliary {kMovingMean, kMovingVar};
+}  // quantized_batchnorm
+
+/*! \brief Parameters for BatchNorm operator */
 struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   double eps;
   float momentum;
@@ -66,6 +74,10 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   bool output_mean_var;
   int axis;
   bool cudnn_off;
+
+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+
   DMLC_DECLARE_PARAMETER(BatchNormParam) {
     DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
     .describe("Epsilon to prevent div 0. "
" @@ -81,19 +93,37 @@ struct BatchNormParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(output_mean_var).set_default(false) .describe("Output the mean and inverse std "); DMLC_DECLARE_FIELD(axis).set_default(mxnet::op::batchnorm::DEFAULT_AXIS) - .describe("Specify which shape axis the channel is specified"); + .describe("Specify which shape axis the channel is specified"); DMLC_DECLARE_FIELD(cudnn_off).set_default(false) - .describe("Do not select CUDNN operator, if available"); + .describe("Do not select CUDNN operator, if available"); + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized batch norm op to calculate primitive scale." + "Note: this calib_range is to calib bn output."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized batch norm op to calculate primitive scale." + "Note: this calib_range is to calib bn output."); } - bool operator==(const BatchNormParam& other) const { - return this->eps == other.eps && - this->momentum == other.momentum && - this->fix_gamma == other.fix_gamma && - this->use_global_stats == other.use_global_stats && - this->output_mean_var == other.output_mean_var && - this->axis == other.axis && - this->cudnn_off == other.cudnn_off; + bool operator==(const BatchNormParam &other) const { + bool flag = this->eps == other.eps && this->momentum == other.momentum && + this->fix_gamma == other.fix_gamma && + this->use_global_stats == other.use_global_stats && + this->output_mean_var == other.output_mean_var && this->axis == other.axis && + this->cudnn_off == other.cudnn_off && + this->min_calib_range.has_value() == other.min_calib_range.has_value() && + this->max_calib_range.has_value() == other.max_calib_range.has_value(); + if (this->min_calib_range.has_value() && other.min_calib_range.has_value() && + this->max_calib_range.has_value() && other.max_calib_range.has_value()) { + flag = flag && this->min_calib_range.value() == other.min_calib_range.value() && + this->max_calib_range.value() == other.max_calib_range.value(); + } + return flag; } }; diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 403baaa94ab4..f294153ecc24 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -132,8 +132,8 @@ class MKLDNNBNForward { return *var_m; } - void SetDataHandle(const NDArray &data, const NDArray &mean, - const NDArray &var, const mkldnn::memory &out) { + void SetDataHandle(const NDArray &data, const mkldnn::memory *mean, + const mkldnn::memory *var, const mkldnn::memory *out) { auto _data = data.GetMKLDNNData(); if (data_m) { data_m->set_data_handle(_data->get_data_handle()); @@ -142,24 +142,22 @@ class MKLDNNBNForward { _data->get_data_handle())); } if (out_m) { - out_m->set_data_handle(out.get_data_handle()); + out_m->set_data_handle(out->get_data_handle()); } else { - out_m.reset(new mkldnn::memory(out.get_primitive_desc(), - out.get_data_handle())); + out_m.reset(new mkldnn::memory(out->get_primitive_desc(), + out->get_data_handle())); } - auto mean_ptr = mean.data().dptr_; if (mean_m) { - mean_m->set_data_handle(mean_ptr); + mean_m->set_data_handle(mean->get_data_handle()); } else { - mean_m.reset(new 
-                                      mean_ptr));
+      mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
+                                      mean->get_data_handle()));
     }
 
-    auto var_ptr = var.data().dptr_;
     if (var_m) {
-      var_m->set_data_handle(var_ptr);
+      var_m->set_data_handle(var->get_data_handle());
     } else {
-      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
-                                     var_ptr));
+      var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
+                                     var->get_data_handle()));
     }
 
     if (fwd == nullptr) {
@@ -175,6 +173,11 @@ class MKLDNNBNForward {
     }
   }
 
+  void SetDataHandle(const NDArray &data, const NDArray &mean,
+                     const NDArray &var, const mkldnn::memory &out) {
+    SetDataHandle(data, mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
+  }
+
   const mkldnn::batch_normalization_forward &GetFwd() const {
     return *fwd;
   }
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
new file mode 100644
index 000000000000..df5e48744f2d
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_quantized_batch_norm.cc
+ * \brief
+ * \author Yixin Bao
+*/
+
+#if MXNET_USE_MKLDNN == 1
+#include "../../nn/mkldnn/mkldnn_batch_norm-inl.h"
+#include "../quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                                            const std::vector<NDArray> &in_data,
+                                            const std::vector<OpReqType> &req,
+                                            const std::vector<NDArray> &outputs) {
+  CHECK_EQ(in_data.size(), 7U);
+  CHECK_EQ(outputs.size(), 3U);
+
+  TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
+  const NDArray &data = in_data[quantized_batchnorm::kData];
+  const size_t channelAxis = static_cast<size_t>(
+      param.axis < 0 ? static_cast<int>(data.shape().ndim()) + param.axis : param.axis);
+  const int channel_count = data.shape()[channelAxis];
+  const float min_data = in_data[quantized_batchnorm::kDataMin].data().dptr<float>()[0];
+  const float max_data = in_data[quantized_batchnorm::kDataMax].data().dptr<float>()[0];
+  const float max_abs_data = std::max(std::abs(min_data), std::abs(max_data));
+
+  float *min_output_ptr = outputs[quantized_batchnorm::kOutMin].data().dptr<float>();
+  float *max_output_ptr = outputs[quantized_batchnorm::kOutMax].data().dptr<float>();
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    *max_output_ptr = param.max_calib_range.value();
+    *min_output_ptr = param.min_calib_range.value();
+  } else {
+    LOG(FATAL) << "min_calib_range or max_calib_range is not available. Quantized BN currently "
+                  "does not support calib_mode=None";
+  }
+  const float max_abs_output = std::max(std::abs(*min_output_ptr), std::abs(*max_output_ptr));
+
+  unsigned flags = mkldnn::use_global_stats | mkldnn::use_scale_shift;
+  auto &fwd = GetBNForward<float>(param, ctx, data, flags);
+  const mkldnn::memory &weight_mem = fwd.GetWeight();
+  CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2);
+  float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());
+
+  float *gamma_ptr = in_data[quantized_batchnorm::kGamma].data().dptr<float>();
+  float *beta_ptr = in_data[quantized_batchnorm::kBeta].data().dptr<float>();
+
+  const NDArray &moving_mean = in_data[quantized_batchnorm::kInMovingMean];
+  const NDArray &moving_var = in_data[quantized_batchnorm::kInMovingVar];
+  float *moving_mean_ptr = moving_mean.data().dptr<float>();
+  float *moving_var_ptr = moving_var.data().dptr<float>();
+
+  // rescale gamma and beta, to make mean=0 and var=1
+  auto rescaled_mean_mem =
+      TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_primitive_desc());
+  auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_primitive_desc());
+  float *rescaled_mean_ptr = reinterpret_cast<float *>(rescaled_mean_mem->get_data_handle());
+  float *rescaled_var_ptr = reinterpret_cast<float *>(rescaled_var_mem->get_data_handle());
+
+#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (int channel = 0; channel < channel_count; ++channel) {
+    float invstd = 1.0 / std::sqrt(moving_var_ptr[channel] + param.eps);
+    weight_buf[channel] = gamma_ptr[channel] * invstd * max_abs_data / max_abs_output;
+    weight_buf[channel_count + channel] =
+        (beta_ptr[channel] - moving_mean_ptr[channel] * gamma_ptr[channel] * invstd) * kInt8Range /
+        max_abs_output;
+    rescaled_mean_ptr[channel] = 0.0f;
+    rescaled_var_ptr[channel] = 1.0f;
+  }
+
+  auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut],
+                                 fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data);
+  fwd.SetDataHandle(data, rescaled_mean_mem, rescaled_var_mem, out_mem.second);
+
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  MKLDNNStream::Get()->Submit();
+}
+
+inline static bool QuantizedBatchNormStorageType(const nnvm::NodeAttrs &attrs, const int dev_mask,
+                                                 DispatchMode *dispatch_mode,
+                                                 std::vector<int> *in_attrs,
+                                                 std::vector<int> *out_attrs) {
+  bool dispatched = false;
+  if (!dispatched) {
+    dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+  }
+  return dispatched;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_batch_norm)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedBatchNormStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedBatchNormForward)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<bool>("TIsMKLDNN", true);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 412e78e70fff..31e3539b8a84 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -37,6 +37,19 @@ using nnvm::NodePtr;
 using nnvm::NodeEntry;
 using nnvm::Graph;
 
+inline size_t GetNumOutputs(NodePtr node) {
+  // Get the number of outputs; if the node provides an FNumVisibleOutputs
+  // function, return the number of visible outputs instead.
+  size_t num_outputs = node->num_outputs();
+  static const auto& num_visible_outputs_attr =
+      nnvm::Op::GetAttr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs");
+  auto num_visible_output_func = num_visible_outputs_attr.get(node->op(), nullptr);
+  if (num_visible_output_func != nullptr) {
+    num_outputs = num_visible_output_func(node->attrs);
+  }
+  return num_outputs;
+}
+
 NodePtr CreateNode(std::string op_name, std::string node_name) {
   NodePtr node = Node::Create();
   node->attrs.name = node_name;
@@ -223,7 +236,7 @@ Graph QuantizeGraph(Graph &&src) {
           // calculate min/max index from mirror node) based on assumption that
           // there is only 1min and 1max output from mirror node (which is
           // currently true)
-          size_t num_outputs = mirror_node->num_outputs() - 2;
+          size_t num_outputs = GetNumOutputs(mirror_node) - 2;
           min_index = num_outputs + 2 * e.index;
           max_index = num_outputs + 2 * e.index + 1;
         } else {
@@ -276,7 +289,7 @@ Graph QuantizeGraph(Graph &&src) {
         // calculate min/max index from mirror node) based on assumption that
         // there is only 1 min and 1 max output from mirror node (which is
         // currently true)
-        size_t num_outputs = mirror_node->num_outputs() - 2;
+        size_t num_outputs = GetNumOutputs(mirror_node) - 2;
         uint32_t min_index = num_outputs + 2 * e.index;
         uint32_t max_index = num_outputs + 2 * e.index + 1;
         NodePtr dequantize_node = CreateNode("_contrib_dequantize",
@@ -309,7 +322,7 @@ Graph QuantizeGraph(Graph &&src) {
       // calculate min/max index from mirror node) based on assumption that
       // there is only 1 min and 1 max output from mirror node (which is
      // currently true)
-      size_t num_outputs = e.node->num_outputs();
+      size_t num_outputs = GetNumOutputs(e.node);
       uint32_t min_index = num_outputs + 2 * e.index;
       uint32_t max_index = num_outputs + 2 * e.index + 1;
 
@@ -403,6 +416,29 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
              << "` has negative input, consider use `auto` or `int8` as out_type";
         }
       }
+    } else if (node->op() == Op::Get("_contrib_quantized_batch_norm")) {
+      auto quantized_op_idx = node->inputs[0].index;
+      const std::string prefix = "quantized_";
+      std::string out_data_name = node->attrs.name.substr(prefix.size());
+      if (node->op()) {
+        auto list_output_names_func = flist_outputs.get(node->op(), nullptr);
+        // We want to get the pre-calculated min_range and max_range from the calibration table
+        // for out_data. Here we create the output data name the same as it is constructed in
+        // GraphExecutor::ExecuteMonCallback.
+        if (list_output_names_func != nullptr) {
+          std::vector<std::string> names = list_output_names_func(node->attrs);
+          out_data_name += "_" + names[quantized_op_idx];
+        } else {
+          out_data_name += "_" + std::to_string(quantized_op_idx);
+        }
+      }
+
+      const auto calib_table_iter = calib_table.find(out_data_name);
+      if (calib_table_iter != calib_table.end()) {
+        node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first);
+        node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second);
+        node->op()->attr_parser(&(node->attrs));
+      }
     }
   });
   return g;
diff --git a/src/operator/quantization/quantized_batch_norm.cc b/src/operator/quantization/quantized_batch_norm.cc
new file mode 100644
index 000000000000..3187826fe996
--- /dev/null
+++ b/src/operator/quantization/quantized_batch_norm.cc
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file quantized_batch_norm.cc
+ * \brief
+ * \author Yixin Bao
+*/
+#include <mxnet/op_attr_types.h>
+#include "../nn/batch_norm-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "../nn/mkldnn/mkldnn_batch_norm-inl.h"
+#endif
+
+namespace mxnet {
+namespace op {
+
+bool QuantizedBatchNormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape,
+                             mxnet::ShapeVector* out_shape) {
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 7U)
+      << "Input:[data, gamma, beta, moving_mean, moving_var, min_data, max_data]";
+  CHECK_EQ(out_shape->size(), 3U);
+
+  const mxnet::TShape& dshape = in_shape->at(batchnorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+  const int channelAxis = param.axis < 0 ? dshape.ndim() + param.axis : param.axis;
+  CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis;
+  const int channelCount = dshape[channelAxis];
+
+  SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape(Shape1(channelCount)))  // gamma,beta
+  SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape(Shape1(channelCount)))
+  SHAPE_ASSIGN_CHECK(*in_shape, 3, mxnet::TShape(Shape1(channelCount)));  // moving_mean, moving_var
+  SHAPE_ASSIGN_CHECK(*in_shape, 4, mxnet::TShape(Shape1(channelCount)))
+  SHAPE_ASSIGN_CHECK(*in_shape, 5, mxnet::TShape(1, 1));  // min_data, max_data
+  SHAPE_ASSIGN_CHECK(*in_shape, 6, mxnet::TShape(1, 1));
+
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
+  SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));  // min_output, max_output
+  SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));
+  return true;
+}
+
+bool QuantizedBatchNormType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
+                            std::vector<int>* out_type) {
+  using namespace mshadow;
+  CHECK_EQ(in_type->size(), 7U);
+  CHECK_EQ(out_type->size(), 3U);
+
+  TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
+  for (size_t i = 1; i < 7; ++i) {
+    TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32);
+  }
+
+  TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt8);
+  TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
+
+  return true;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_batch_norm)
+.describe(R"code(BatchNorm operator for input and output data type of int8.
+The input and output data come with min and max thresholds for quantizing
+the float32 data into int8.
+
+.. Note::
+    This operator only supports forward propagation. DO NOT use it in training.
+
+)code" ADD_FILELINE) +.set_num_inputs(7) +.set_num_outputs(3) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "gamma", "beta", + "moving_mean", "moving_var", "min_data", "max_data"}; + }) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "min_output", "max_output"}; + }) +.set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { + return std::vector{3, 4}; +}) +.set_attr("FInferShape", QuantizedBatchNormShape) +.set_attr("FInferType", QuantizedBatchNormType) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return false; }) +.add_argument("data", "NDArray-or-Symbol", "Input data.") +.add_argument("gamma", "NDArray-or-Symbol", "gamma.") +.add_argument("beta", "NDArray-or-Symbol", "beta.") +.add_argument("moving_mean", "NDArray-or-Symbol", "moving_mean.") +.add_argument("moving_var", "NDArray-or-Symbol", "moving_var.") +.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.") +.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.") +.add_arguments(BatchNormParam::__FIELDS__()); + +NNVM_REGISTER_OP(BatchNorm) +.set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_batch_norm"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }) +.set_attr("FAvoidQuantizeInput", [](const NodeAttrs &attrs, size_t index) { + if (index == 0) + return false; + else + return true; +}); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index a991417b28c9..0cf38b7375a8 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -594,6 +594,88 @@ def check_quantized_act(data_shape, qdtype): check_quantized_act((10, 15, 18), qdtype) check_quantized_act((3, 4, 23, 23), qdtype) +@with_seed() +def test_quantized_bn(): + def get_mean_var(data): + mean = mx.ndarray.mean(data, axis=1, exclude=1) + mean_broad = mx.ndarray.expand_dims(mean, axis=0) + mean_broad = mx.ndarray.expand_dims(mean_broad, axis=2) + mean_broad = mx.ndarray.expand_dims(mean_broad, axis=3) + mean_broad = mx.ndarray.broadcast_like(mean_broad, data) + var = mx.ndarray.multiply(data - mean_broad, data - mean_broad) + var = mx.ndarray.mean(var, axis=1, exclude=1) + return mean, var + + def check_quantized_bn(data_shape, qdtype): + if qdtype == 'uint8': + print('skipped testing quantize_bn for uint8 since it is not supported yet') + return + elif is_test_for_native_cpu(): + print('skipped testing quantize_bn for native cpu since it is not supported yet') + return + elif is_test_for_gpu(): + print('skipped testing quantize_bn for gpu since it is not supported yet') + return + + # qdtype = int8 + data_low = -127.0 + data_high = 127.0 + quantized_range = 127.0 + # run fp32 bn + data_sym = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + bn_fp32 = mx.sym.BatchNorm(data=data_sym, name='bn', use_global_stats=True, fix_gamma=False) + arg_shapes, out_shapes, aux_shapes = bn_fp32.infer_shape(data=data_shape) + arg_names = bn_fp32.list_arguments() + aux_names = bn_fp32.list_auxiliary_states() + + data = mx.nd.random.uniform(low=data_low, high=data_high, 
+        data = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape)
+        gamma = mx.nd.random.uniform(low=data_low, high=data_high, shape=arg_shapes[1])
+        beta = mx.nd.random.uniform(low=data_low, high=data_high, shape=arg_shapes[2])
+        moving_mean, moving_var = get_mean_var(data)
+
+        bn_fp32_exe = bn_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
+        bn_fp32_exe.arg_dict[arg_names[0]][:] = data
+        bn_fp32_exe.arg_dict[arg_names[1]][:] = gamma
+        bn_fp32_exe.arg_dict[arg_names[2]][:] = beta
+        bn_fp32_exe.aux_dict[aux_names[0]][:] = moving_mean
+        bn_fp32_exe.aux_dict[aux_names[1]][:] = moving_var
+        min_data = mx.nd.min(data)
+        max_data = mx.nd.max(data)
+        data_range = mx.nd.maximum(mx.nd.abs(min_data), mx.nd.abs(max_data))
+
+        output = bn_fp32_exe.forward()[0]
+
+        # generate int8 bn from fp32 bn
+        arg_params = dict()
+        for k, v in bn_fp32_exe.arg_dict.items():
+            if 'data' in k or 'softmax_label' in k:
+                continue
+            arg_params[k] = v
+
+        calib_data = NDArrayIter(data=data, batch_size=data_shape[0])
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=bn_fp32,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=bn_fp32_exe.aux_dict,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+
+        mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
+        mod.bind(for_training=False, data_shapes=[('data', data_shape)])
+        mod.set_params(qarg_params, qaux_params)
+        batch = mx.io.DataBatch([data], [])
+        mod.forward(batch, is_train=False)
+        output_int8_to_fp32 = mod.get_outputs()[0]
+
+        assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=3)
+
+    check_quantized_bn((32, 512, 4, 4), 'int8')
+    check_quantized_bn((32, 1024, 8, 8), 'int8')
+    check_quantized_bn((32, 3, 224, 224), 'int8')
+
 @with_seed()
 def test_quantize_params():
     data = mx.sym.Variable('data')
@@ -835,6 +917,12 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N
         if qdtype == 'int8' and is_test_for_mkldnn() and name in ['sym1', 'sym2', 'sym3']:
             print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
             continue
+        elif qdtype == 'uint8' and is_test_for_mkldnn() and name in ['sym1']:
+            print('skipping test_quantize_model_with_forward for mkldnn cpu uint8 since it is not supported yet')
+            continue
+        elif qdtype == 'int8' and is_test_for_gpu() and name in ['sym1']:
+            print('skipped testing test_quantize_model_with_forward for gpu int8 since it is not supported yet')
+            continue
 
         if lshape is None:
             mod = Module(symbol=s, label_names=None)
@@ -905,7 +993,7 @@ def check_quantize_net(qdtype):
         if is_test_for_native_cpu():
             print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
             return
-        elif qdtype == 'uint8' and is_test_for_gpu():
+        elif is_test_for_gpu():
            print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
            return
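
Editor's note (not part of the patch): the core of MKLDNNQuantizedBatchNormForward above is folding the batch-norm statistics into a per-channel scale/shift that is applied directly to the int8 data, with the input and output scales taken from min/max_data and min/max_calib_range. A minimal NumPy sketch of that arithmetic follows; the function and variable names are illustrative only, and the int8 range is assumed to be 127 (kInt8Range in MXNet's quantization utilities).

# Illustrative sketch only -- mirrors the weight_buf computation in
# mkldnn_quantized_batch_norm.cc; it is not code from the patch.
import numpy as np

def fold_bn_into_int8_scale_shift(gamma, beta, moving_mean, moving_var, eps,
                                  max_abs_data, max_abs_output):
    # The int8 input x_q encodes x = x_q * max_abs_data / 127, and the int8
    # output y_q should encode y = y_q * max_abs_output / 127, so
    #   y = gamma * (x - mean) / sqrt(var + eps) + beta
    # becomes y_q = scale * x_q + shift with per-channel values:
    invstd = 1.0 / np.sqrt(moving_var + eps)
    scale = gamma * invstd * max_abs_data / max_abs_output            # weight_buf[channel]
    shift = (beta - moving_mean * gamma * invstd) * 127.0 / max_abs_output  # weight_buf[channel_count + channel]
    return scale, shift

With the moving mean and variance rewritten to 0 and 1 (as the rescaled_mean/rescaled_var buffers do in the patch), the MKLDNN batch-norm primitive then reduces to exactly this scale/shift on the quantized tensor.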