diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 06a1272caf21..482127ba355c 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -187,7 +187,7 @@ def save_params(fname, arg_params, aux_params, logger=None): prefix, epoch = download_model(model_name=args.model, logger=logger) sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) - sym = sym.get_backend_symbol('MKLDNN') + sym = sym.get_backend_symbol('MKLDNN_QUANTIZE') # get batch size batch_size = args.batch_size @@ -315,7 +315,7 @@ def save_params(fname, arg_params, aux_params, logger=None): raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' % calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) - qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') + qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py index 4b111dfa1875..d50935499240 100644 --- a/example/ssd/quantization.py +++ b/example/ssd/quantization.py @@ -101,7 +101,7 @@ def save_params(fname, arg_params, aux_params, logger=None): label = mx.sym.Variable(name='label') sym = mx.sym.Group([sym, label]) - sym = sym.get_backend_symbol('MKLDNN') + sym = sym.get_backend_symbol('MKLDNN_QUANTIZE') # get batch size batch_size = args.batch_size @@ -163,6 +163,6 @@ def calib_layer(name): return not (name.endswith('_data') or label_names=(label_name,), logger=logger) sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300') param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch) - qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') + qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') save_symbol(sym_name, qsym, logger) save_params(param_name, qarg_params, aux_params, logger) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index f660b97f8789..086dbc07a043 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -900,8 +900,9 @@ def load(self, filename, ctx=None, allow_missing=False, "restore_prefix is '%s' but Parameters name '%s' does not start " \ "with '%s'"%(restore_prefix, name, restore_prefix) lprefix = len(restore_prefix) + ndarray_load = ndarray.load(filename) loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \ - for k, v in ndarray.load(filename).items()] + for k, v in ndarray_load.items()] if isinstance(ndarray_load, dict) else ndarray_load arg_dict = {restore_prefix+k: v for k, v in loaded} if not allow_missing: for name in self.keys(): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index d80fab58be42..7b46be487488 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -80,6 +80,11 @@ def get_rtol(rtol=None): # be needed for different device and dtype return 1e-5 if rtol is None else rtol +def get_etol(etol=None): + """Get default numerical threshold for regression test.""" + # _TODO: get from env variable, different threshold might + # be needed for different device and dtype + return 0 if etol is None else etol def random_arrays(*shapes): """Generate some random numpy arrays.""" @@ -494,6 +499,50 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan= names=names) raise 
AssertionError(msg) +def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('a', 'b'), equal_nan=False): + """Test that two numpy arrays are almost equal within given error rate. Raise exception message if not. + + Parameters + ---------- + a : np.ndarray + b : np.ndarray + threshold : None or float + The checking threshold. Default threshold will be used if set to ``None``. + etol : None or float + The error rate threshold. If etol is float, return true if error_rate < etol even if + any error is found. + """ + rtol = get_rtol(rtol) + atol = get_atol(atol) + etol = get_etol(etol) + if etol: + equals = np.isclose(a, b, rtol=rtol, atol=atol) + err = 1 - np.count_nonzero(equals) / equals.size + if err > etol: + #if True: + index, rel = find_max_violation(a, b, rtol, atol) + np.set_printoptions(threshold=4, suppress=True) + msg = npt.build_err_msg([a, b], + err_msg="Error %f exceeds tolerance rtol=%f, atol=%f, etol=%f." + " Error_rate=%f. Location of maximum error:%s, a=%f, b=%f" + % (rel, rtol, atol, etol, err, str(index), a[index], b[index]), + names=names) + raise AssertionError(msg) + + if almost_equal(a, b, rtol, atol, equal_nan=equal_nan): + return + else: + if almost_equal(a, b, rtol, atol, equal_nan=equal_nan): + return + index, rel = find_max_violation(a, b, rtol, atol) + np.set_printoptions(threshold=4, suppress=True) + msg = npt.build_err_msg([a, b], + err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. " + " Location of maximum error:%s, a=%f, b=%f" + % (rel, rtol, atol, str(index), a[index], b[index]), + names=names) + raise AssertionError(msg) + def almost_equal_ignore_nan(a, b, rtol=None, atol=None): """Test that two NumPy arrays are almost equal (ignoring NaN in either array). diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index 9ce27fad4b19..e4c829645e13 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -57,6 +57,12 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) { return SupportMKLDNNAct(param); } +bool SupportQuantizedMKLDNNAct(const ActivationParam ¶m) { + // TODO(zhennan): Add more activation type when mkldnn supports. + // Remove this when it's identity to SupportMKLDNNAct. 
+ return param.act_type == activation::kReLU; +} + mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) { switch (param.act_type) { case activation::kReLU: diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index b05157942131..5670983e6aa3 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -178,7 +178,8 @@ struct SoftmaxOutputParam; struct TransposeParam; bool SupportMKLDNNAct(const ActivationParam& param); bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input); -bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input); +bool SupportQuantizedMKLDNNAct(const ActivationParam ¶m); +bool SupportMKLDNNConv(const ConvolutionParam ¶ms, const NDArray &input); bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input); bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output); bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index ab6650eadad7..b4289e524999 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -38,9 +38,9 @@ namespace op { struct MKLDNNConvParam : public dmlc::Parameter { bool with_bn; - bool with_relu; + bool with_act; bool with_sum; - bool with_postsum_relu; + bool with_postsum_act; bool quantized; dmlc::optional min_calib_range; // min float value calculated from calibration dataset @@ -49,12 +49,12 @@ struct MKLDNNConvParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(MKLDNNConvParam) { DMLC_DECLARE_FIELD(with_bn).set_default(false) .describe("Add post batchnorm."); - DMLC_DECLARE_FIELD(with_relu).set_default(false) - .describe("Add post relu"); + DMLC_DECLARE_FIELD(with_act).set_default(false) + .describe("Add post activation"); DMLC_DECLARE_FIELD(with_sum).set_default(false) .describe("Add post sum"); - DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false) - .describe("Add post relu after sum"); + DMLC_DECLARE_FIELD(with_postsum_act).set_default(false) + .describe("Add post activation after sum"); DMLC_DECLARE_FIELD(quantized).set_default(false) .describe("enable quantization"); DMLC_DECLARE_FIELD(min_calib_range) @@ -70,18 +70,22 @@ struct MKLDNNConvParam : public dmlc::Parameter { } }; +struct MKLDNNPostActParam { + mkldnn::algorithm alg = mkldnn::algorithm::algorithm_undef; + float scale = 1.f; + float alpha = 0.f; + float beta = 1.f; +}; + struct MKLDNNConvFullParam { ConvolutionParam conv_param; MKLDNNConvParam mkldnn_param; - float sum_scale; + float sum_scale = 1.f; std::vector requantize_scales; + MKLDNNPostActParam act_param; + MKLDNNPostActParam postsum_act_param; }; -static inline bool IsOutputUInt8(const MKLDNNConvParam &mkldnn_param) { - return ((!mkldnn_param.with_sum) && mkldnn_param.with_relu) || - mkldnn_param.with_postsum_relu; -} - mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index a3aca98d9f81..d32a6a343d7d 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -82,20 +82,16 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP } mkldnn::primitive_attr attr; mkldnn::post_ops ops; - 
if (param.mkldnn_param.with_relu) { - float scale = 1.0f; // for fp32, scale is 1. - float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. - float beta = 1.0f; // ignored for mkldnn_eltwise_relu. - ops.append_eltwise(scale, eltwise_relu, alpha, beta); + if (param.mkldnn_param.with_act) { + const auto &act_param = param.act_param; + ops.append_eltwise(act_param.scale, act_param.alg, act_param.alpha, act_param.beta); } if (param.mkldnn_param.with_sum) { ops.append_sum(param.sum_scale); } - if (param.mkldnn_param.with_postsum_relu) { - float scale = 1.0f; // for fp32, scale is 1. - float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. - float beta = 1.0f; // ignored for mkldnn_eltwise_relu. - ops.append_eltwise(scale, eltwise_relu, alpha, beta); + if (param.mkldnn_param.with_postsum_act) { + const auto &act_param = param.postsum_act_param; + ops.append_eltwise(act_param.scale, act_param.alg, act_param.alpha, act_param.beta); } attr.set_post_ops(ops); diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 28b89613ee86..e0fb615a7ac0 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -299,6 +299,7 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph } LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name << ". Excluding nodes " << excluded_node_names << "and retrying"; + subgraph_selector->Reset(); } ++count; } @@ -509,9 +510,9 @@ void FindOutputEntries(nnvm::Graph* g, void CutGraphInputs(const std::vector &input_entries, std::vector *orig_entries, const bool skip_var = false) { - orig_entries->reserve(input_entries.size()); + orig_entries->resize(input_entries.size()); // map for creating unique var nodes for deduplicating entries from the same node - std::unordered_map new_node_map; + std::unordered_map name_count_map; for (size_t i = 0; i < input_entries.size(); ++i) { nnvm::NodeEntry *e = input_entries[i]; // If the node is a variable itself, we may want to skip the node. 
@@ -519,17 +520,19 @@ void CutGraphInputs(const std::vector &input_entries, continue; } + orig_entries->at(i) = *e; nnvm::Symbol sym; sym.outputs.push_back(*e); const auto output_names = sym.ListOutputNames(); CHECK_EQ(output_names.size(), 1U); const std::string& var_name = output_names[0]; - auto it = new_node_map.find(var_name); - if (it == new_node_map.end()) { - orig_entries->push_back(*e); - new_node_map[var_name] = nnvm::CreateVariableNode(var_name); + auto it = name_count_map.find(var_name); + if (name_count_map.end() == it) { + name_count_map.emplace(var_name, 0); + } else { + ++(it->second); } - nnvm::NodePtr n = new_node_map[var_name]; + nnvm::NodePtr n = nnvm::CreateVariableNode(var_name + std::to_string(name_count_map[var_name])); *e = nnvm::NodeEntry{n, 0, 0}; } } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h index b44f2fb0e31e..fcf767adebad 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h @@ -21,11 +21,12 @@ #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ #if MXNET_USE_MKLDNN == 1 +#include #include #include -#include -#include "../../nn/convolution-inl.h" +#include "../../nn/activation-inl.h" #include "../../nn/batch_norm-inl.h" +#include "../../nn/convolution-inl.h" #include "../../nn/mkldnn/mkldnn_convolution-inl.h" namespace mxnet { @@ -36,6 +37,25 @@ struct MKLDNNConvFusionParam { std::shared_ptr bn_param; }; +static inline bool IsOutputUInt8(const MKLDNNConvFusionParam& param) { + bool result = false; + const auto& mkldnn_param = param.full_conv_param.mkldnn_param; + auto IsOutputUInt8Helper = [](const mkldnn::algorithm& act_alg) { + return (act_alg == mkldnn::algorithm::eltwise_relu || + act_alg == mkldnn::algorithm::eltwise_logistic || + act_alg == mkldnn::algorithm::eltwise_soft_relu || + act_alg == mkldnn::algorithm::eltwise_bounded_relu); + }; + if ((!mkldnn_param.with_sum) && mkldnn_param.with_act) { + CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::algorithm_undef); + result = IsOutputUInt8Helper(param.full_conv_param.act_param.alg); + } else if (mkldnn_param.with_postsum_act) { + CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::algorithm_undef); + result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param.alg); + } + return result; +} + enum MKLDNNConvOpOutputs { kOut, kMin, kMax }; } // namespace op diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 2c05fda9a879..b7776d648e18 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -27,6 +27,8 @@ #include "../../nn/mkldnn/mkldnn_ops-inl.h" #include "../../quantization/quantization_utils.h" #include "mkldnn_conv-inl.h" +#include "../../nn/mkldnn/mkldnn_act-inl.h" +#include "../../tensor/matrix_op-inl.h" namespace mxnet { namespace op { @@ -294,7 +296,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, cached_data_max_ = data_max; cached_sum_min_ = sum_min; cached_sum_max_ = sum_max; - full_conv_param.sum_scale = 1.0; cached_weight_ = inputs[in_weight].Reorder2Default(); weight_ver_ = inputs[in_weight].version(); if (!conv_param.no_bias) { @@ -348,7 +349,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_); } if (post_requantize_) { - quantized_out_range = IsOutputUInt8(mkldnn_param) ? 
kUint8Range : kInt8Range; + quantized_out_range = IsOutputUInt8(param_) ? kUint8Range : kInt8Range; out_range = MaxAbs(cached_output_min_, cached_output_max_); output_scale = quantized_out_range / out_range; full_conv_param.requantize_scales.resize(weight_channelwise_scale ? channel : 1); @@ -373,6 +374,19 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, if (mkldnn_param.with_sum) { full_conv_param.sum_scale = output_scale / sum_in_scale; } + if (mkldnn_param.with_act && + full_conv_param.act_param.alg == mkldnn::algorithm::eltwise_bounded_relu) { + if (mkldnn_param.with_sum) { + LOG(ERROR) << "mkldnn doesn't support conv + relu + sum fusion yet."; + full_conv_param.act_param.alpha *= output_scale; + } else { + // For conv+relu6 without sum, we don't need post_ops as output_scale can do the cut off. + mkldnn_param.with_act = false; + } + } + if (mkldnn_param.with_postsum_act) { + CHECK(full_conv_param.postsum_act_param.alg == mkldnn::algorithm::eltwise_relu); + } } fwd_.reset(new MKLDNNConvForward( full_conv_param, ctx.is_train, data, cached_weight_, @@ -385,6 +399,25 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, initialized_ = true; } + if (mkldnn_param.with_sum) { + const auto output_mem = output.GetMKLDNNData(); + const auto out_mem_desc = output_mem->get_primitive_desc().desc(); + const auto dst_format = fwd_->fwd_pd.dst_primitive_desc().desc().data.format; + if (out_mem_desc.data.format != dst_format) { + auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->fwd_pd.dst_primitive_desc()); + mkldnn::memory::desc data_md( + mkldnn::memory::dims(out_mem_desc.data.dims, + out_mem_desc.data.dims + out_mem_desc.data.ndims), + static_cast(out_mem_desc.data.data_type), + static_cast(dst_format)); + mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); + mkldnn_mem_ptr new_out_mem(new mkldnn::memory(pd, output_mem->get_data_handle())); + MKLDNNStream::Get()->RegisterMem(new_out_mem); + mxnet::MKLDNNCopy(*tmp_out_mem, new_out_mem.get()); + output = NDArray(new_out_mem); + } + } + if (mkldnn_param.quantized) { auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc()); mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc()); @@ -437,6 +470,23 @@ static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs &attrs) { static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { MKLDNNConvFusionParam param_; + + // For back-compatible, rename + // with_relu -> with_act + // with_postsum_relu -> with_postsum_act + + auto old = attrs->dict.find("with_relu"); + if (old != attrs->dict.end()) { + attrs->dict["with_act"] = old->second; + attrs->dict.erase(old); + } + + old = attrs->dict.find("with_postsum_relu"); + if (old != attrs->dict.end()) { + attrs->dict["with_postsum_act"] = old->second; + attrs->dict.erase(old); + } + try { param_.full_conv_param.mkldnn_param.Init(attrs->dict); } catch (const dmlc::ParamError &e) { @@ -452,6 +502,7 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { } CHECK_EQ(attrs->subgraphs.size(), 1); auto subgraph_sym = attrs->subgraphs[0]; + bool with_act = false; DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) { if (node->is_variable()) return; auto &node_name = node->op()->name; @@ -463,6 +514,20 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { } else if (node_name == "Convolution") { param_.full_conv_param.conv_param = nnvm::get(node->attrs.parsed); + } else if (node_name == "Activation" || node_name == "clip") { + auto &post_act_param = + 
(param_.full_conv_param.mkldnn_param.with_act && !with_act) + ? param_.full_conv_param.act_param + : param_.full_conv_param.postsum_act_param; + with_act = true; + if (node_name == "Activation") { + const auto act_param = nnvm::get(node->attrs.parsed); + post_act_param.alg = GetMKLDNNActAlgo(act_param); + } else { + const auto clip_param = nnvm::get(node->attrs.parsed); + post_act_param.alg = mkldnn::algorithm::eltwise_bounded_relu; + post_act_param.alpha = clip_param.a_max; + } } }); attrs->parsed = std::move(param_); @@ -605,7 +670,7 @@ static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs, } if (param.full_conv_param.mkldnn_param.min_calib_range.has_value() && param.full_conv_param.mkldnn_param.max_calib_range.has_value()) { - if (IsOutputUInt8(param.full_conv_param.mkldnn_param)) { + if (IsOutputUInt8(param)) { TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); } else { TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h index 7fe4727a4990..a39b4ebe4fc5 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h @@ -23,15 +23,17 @@ #include #include +#include "../../nn/activation-inl.h" +#include "../../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../tensor/matrix_op-inl.h" #include "../common.h" #include "../subgraph_property.h" -#include "../../nn/activation-inl.h" namespace mxnet { namespace op { class SgMKLDNNConvSelector : public SubgraphSelector { public: - /*! \brief pattern match status */ + /*! \brief pattern match status_ */ enum SelectStatus { kFail = 0, kStart, @@ -41,25 +43,28 @@ class SgMKLDNNConvSelector : public SubgraphSelector { }; private: - bool disable_all; - bool disable_conv_bn; - bool disable_conv_relu; - bool disable_conv_sum; - SelectStatus status; - std::vector matched_list; + bool disable_all_; + bool disable_conv_bn_; + bool disable_conv_act_; + bool disable_conv_sum_; + bool quantize_; + SelectStatus status_; + std::vector matched_list_; public: - SgMKLDNNConvSelector(int dis_all, int dis_conv_bn, int dis_conv_relu, int dis_conv_sum) - : disable_all(dis_all), - disable_conv_bn(dis_conv_bn), - disable_conv_relu(dis_conv_relu), - disable_conv_sum(dis_conv_sum) {} + SgMKLDNNConvSelector(int dis_all, int dis_conv_bn, int dis_conv_act, int dis_conv_sum, + int quantize) + : disable_all_(dis_all), + disable_conv_bn_(dis_conv_bn), + disable_conv_act_(dis_conv_act), + disable_conv_sum_(dis_conv_sum), + quantize_(quantize) {} bool Select(const nnvm::Node &n) override { if (n.op() && n.op()->name == "Convolution") { - status = disable_all ? kSuccess : kStart; - matched_list.clear(); - matched_list.push_back(&n); + status_ = disable_all_ ? kSuccess : kStart; + matched_list_.clear(); + matched_list_.push_back(&n); return true; } return false; @@ -72,60 +77,72 @@ class SgMKLDNNConvSelector : public SubgraphSelector { bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. 
- if (matched_list.back() != &n) { - if (std::find(matched_list.begin(), matched_list.end(), &n) != - matched_list.end()) { - while (matched_list.back() != &n) { - matched_list.pop_back(); + if (matched_list_.back() != &n) { + if (std::find(matched_list_.begin(), matched_list_.end(), &n) != + matched_list_.end()) { + while (matched_list_.back() != &n) { + matched_list_.pop_back(); } } - status = kSuccess; + status_ = kSuccess; return false; } - if (status == kFail || status == kSuccess || new_node.is_variable()) + if (status_ == kFail || status_ == kSuccess || new_node.is_variable()) return false; - // Use status machine to do selection. The status change is + // Use status_ machine to do selection. The status_ change is // kStart -> kBN -> kSum -> kSuccess - switch (status) { + switch (status_) { case kStart: - if ((!disable_conv_bn) && new_node.op()->name == "BatchNorm") { - matched_list.push_back(&new_node); - status = kBN; + if ((!disable_conv_bn_) && new_node.op()->name == "BatchNorm") { + matched_list_.push_back(&new_node); + status_ = kBN; return true; } case kBN: - if ((!disable_conv_sum) && new_node.op()->name == "elemwise_add") { - matched_list.push_back(&new_node); - status = kSum; + if ((!disable_conv_sum_) && new_node.op()->name == "elemwise_add") { + matched_list_.push_back(&new_node); + status_ = kSum; return true; } case kSum: default: - if ((!disable_conv_relu) && new_node.op()->name == "Activation") { + if ((!disable_conv_act_) && new_node.op()->name == "Activation") { const ActivationParam ¶m = nnvm::get(new_node.attrs.parsed); - if (param.act_type == activation::kReLU) { - matched_list.push_back(&new_node); - // If we find conv+relu, then we can't match anymore. - // TODO(zhennan): mkldnn only supports convolution + relu + sum in - // int8, not in fp32. So we disable this pattern at moment. - status = kSuccess; + if ((quantize_ && SupportQuantizedMKLDNNAct(param)) || + (!quantize_ && SupportMKLDNNAct(param))) { + matched_list_.push_back(&new_node); + // not support conv+relu+sum yet. + status_ = kSuccess; return true; } + } else if ((!disable_conv_act_) && new_node.op()->name == "clip") { + if (!(quantize_ && (status_ == kSum))) { + // TODO(zhennan): doesn't support int8 conv+sum+relu6 at moment. To support this, we + // need to fuse conv+sum first, and calibrate with it. Then fuse int8 relu6 into fused + // conv. + const ClipParam ¶m = nnvm::get(new_node.attrs.parsed); + if (param.a_min == 0.f) { + matched_list_.push_back(&new_node); + // not support conv+relu+sum yet. 
+ status_ = kSuccess; + return true; + } + } } - status = kSuccess; + status_ = kSuccess; return false; } } std::vector Filter( const std::vector &candidates) override { - if (status == kFail) { + if (status_ == kFail) { return std::vector(0); } else { std::vector ret; - for (auto i : matched_list) { + for (auto i : matched_list_) { auto non_const_i = const_cast(i); if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { @@ -135,16 +152,24 @@ class SgMKLDNNConvSelector : public SubgraphSelector { return ret; } } + + void Reset() override { + CHECK_GE(matched_list_.size(), 1); + auto new_selector = SgMKLDNNConvSelector(disable_all_, disable_conv_bn_, disable_conv_act_, + disable_conv_sum_, quantize_); + new_selector.Select(*matched_list_[0]); + *this = new_selector; + } }; class SgMKLDNNConvProperty : public SubgraphProperty { public: SgMKLDNNConvProperty() { - disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0); - disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0); - disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0); + disable_conv_bn_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0); + disable_conv_act_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0); + disable_conv_sum_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0); - disable_all = disable_conv_bn && disable_conv_relu && disable_conv_sum; + disable_all_ = disable_conv_bn_ && disable_conv_act_ && disable_conv_sum_; } static SubgraphPropertyPtr Create() { static const std::string &name = "MKLDNN convolution optimization pass"; @@ -180,12 +205,12 @@ class SgMKLDNNConvProperty : public SubgraphProperty { n->attrs.dict["with_sum"] = "true"; _with_sum = true; - } else if (sub_name == "Activation") { - node_name << "relu_"; + } else if (sub_name == "Activation" || sub_name == "clip") { + node_name << "act_"; if (!_with_sum) { - n->attrs.dict["with_relu"] = "true"; + n->attrs.dict["with_act"] = "true"; } else { - n->attrs.dict["with_postsum_relu"] = "true"; + n->attrs.dict["with_postsum_act"] = "true"; } } }); @@ -199,8 +224,9 @@ class SgMKLDNNConvProperty : public SubgraphProperty { } SubgraphSelectorPtr CreateSubgraphSelector() const override { + int quantize = HasAttr("quantize") ? 
GetAttr("quantize") : 0; auto selector = std::make_shared( - disable_all, disable_conv_bn, disable_conv_relu, disable_conv_sum); + disable_all_, disable_conv_bn_, disable_conv_act_, disable_conv_sum_, quantize); return selector; } @@ -241,10 +267,10 @@ class SgMKLDNNConvProperty : public SubgraphProperty { } private: - int disable_all; - int disable_conv_bn; - int disable_conv_relu; - int disable_conv_sum; + int disable_all_; + int disable_conv_bn_; + int disable_conv_act_; + int disable_conv_sum_; }; } // namespace op diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h index f8d7ee1da6c9..8b5c08802986 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h @@ -132,6 +132,13 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { return ret; } } + + void Reset() override { + CHECK_GE(matched_list.size(), 1); + auto new_selector = SgMKLDNNFCPostQuantizeSelector(disable_all, disable_float_output); + new_selector.Select(*matched_list[0]); + *this = new_selector; + } }; class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h index 04e140c72d86..136fcb32335a 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h @@ -115,6 +115,13 @@ class SgMKLDNNFCSelector : public SubgraphSelector { return candidates; } } + + void Reset() override { + CHECK_GE(matched_list.size(), 1); + auto new_selector = SgMKLDNNFCSelector(disable_fc_relu); + new_selector.Select(*matched_list[0]); + *this = new_selector; + } }; class SgMKLDNNFCProperty : public SubgraphProperty { diff --git a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h index f8c47f0ce036..5c5037e7a116 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h @@ -34,6 +34,7 @@ class SgMKLDNNConcatPostQuantizeSelector : public SubgraphSelectorV2 { bool Select(const BiDirectedNode &sn) override { const auto &n = *sn.node; if (n.op() == Op::Get("_contrib_quantized_concat")) { + head_ = sn; matched_list_.clear(); visit_list_.clear(); visit_list_.insert(&n); @@ -97,7 +98,14 @@ class SgMKLDNNConcatPostQuantizeSelector : public SubgraphSelectorV2 { } } + void Reset() override { + auto new_selector = SgMKLDNNConcatPostQuantizeSelector(); + new_selector.Select(head_); + *this = new_selector; + } + private: + BiDirectedNode head_; bool select_output_; std::vector matched_list_; std::unordered_set visit_list_; @@ -127,7 +135,7 @@ class SgMKLDNNPostQuantizeAlignScaleProperty : public SubgraphProperty { * conv4 = mx.symbol.Convolution(data=data, weight=weight * 4, name='conv4', num_filter=64, * kernel=(3, 3), stride=(1, 1), no_bias=True) * concat = mx.symbol.Concat(*[conv1, conv2, conv3, conv4], name="concat", dim=1) - * + * * This pass will collect the maximum calib range from conv1 to conv4, and apply it to all * conv1 to conv4. Then concat don't need extra scale alignment operation. Performance and * accuracy are both improved. 
diff --git a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h index b61a303757b3..e78b8d1bfa42 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h @@ -107,6 +107,13 @@ class SgMKLDNNPostQuantizeSelector : public SubgraphSelector { return candidates; } } + + void Reset() override { + CHECK_GE(matched_list.size(), 1); + auto new_selector = SgMKLDNNPostQuantizeSelector(); + new_selector.Select(*matched_list[0]); + *this = new_selector; + } }; class SgMKLDNNPostQuantizeProperty : public SubgraphProperty { diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc index 4fc2d2c024bf..7fbc859cc8d1 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc @@ -29,10 +29,20 @@ namespace mxnet { namespace op { MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); + MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty); -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNPostQuantizeProperty); -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty); + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty) +.set_attr("quantize", true); + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty) +.set_attr("quantize", true); + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty); + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index aac3b3f2d0fc..460055f9ed86 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -104,6 +104,11 @@ class SubgraphSelector { virtual std::vector Filter(const std::vector& candidates) { return candidates; } + /*! + * \brief Reset the state of selector for SelectInput. + * Note: the state should reset to Select() is successful. + */ + virtual void Reset() {} }; using SubgraphSelectorPtr = std::shared_ptr; @@ -141,6 +146,12 @@ class SubgraphSelectorV2 { const std::vector& candidates) { return candidates; } + + /*! + * \brief Reset the state of selector for SelectInput. + * Note: the state should reset to Select() is successful. + */ + virtual void Reset() {} }; using SubgraphSelectorV2Ptr = std::shared_ptr; @@ -179,6 +190,8 @@ class SubgraphSelectorV2Bridge : public SubgraphSelectorV2 { return ret; } + void Reset() override { ss_ptr_->Reset(); } + const SubgraphSelectorPtr& GetV1ptr() const { return ss_ptr_; } private: @@ -257,7 +270,7 @@ class SubgraphProperty { /*! * \brief Adjust nnvm nodes from a given subgraph. No new node is created, but adjust * selected nodes' attributes. This can be used to implement peephole optimization. - * Here users can customize how to adjust the operators in the subgraph. + * Here users can customize how to adjust the operators in the subgraph. * \param subgraph_nodes the subgraph nodes to adjust * \param subgraph_selector The selector used for selecting this node set. 
* \param subgraph_id subgraph id @@ -329,6 +342,18 @@ class SubgraphProperty { using SubgraphPropertyPtr = std::shared_ptr; +class SubgraphPropertyEntry { + public: + explicit SubgraphPropertyEntry(std::shared_ptr entry) : entry_(entry) {} + SubgraphPropertyEntry set_attr(const std::string& name, const int value) const { + entry_->SetAttr(name, value); + return *this; + } + + private: + std::shared_ptr entry_; +}; + class SubgraphPropertyRegistry { public: typedef SubgraphPropertyPtr (*SubgraphPropertyCreateFn)(void); @@ -338,33 +363,22 @@ class SubgraphPropertyRegistry { } std::vector CreateSubgraphProperty(const std::string& name) { - auto it = prop_fn_map_.find(name); - CHECK(it != prop_fn_map_.end()) << "SubgraphProperty " << name + auto it = prop_ptr_map_.find(name); + CHECK(it != prop_ptr_map_.end()) << "SubgraphProperty " << name << " is not found in SubgraphPropertyRegistry"; - std::vector ret; - ret.reserve(it->second.size()); - for (auto i : it->second) { - auto ptr_it = prop_ptr_map_.find(i); - if (ptr_it == prop_ptr_map_.end()) { - prop_ptr_map_[i] = i(); - ptr_it = prop_ptr_map_.find(i); - } - if (ptr_it->second) ret.emplace_back(ptr_it->second); - } - return ret; + return it->second; } - SubgraphPropertyCreateFn __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) { - prop_fn_map_[name].push_back(fn); - return fn; + SubgraphPropertyEntry __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) { + prop_ptr_map_[name].emplace_back(fn()); + return SubgraphPropertyEntry(prop_ptr_map_[name].back()); } SubgraphPropertyRegistry() = default; SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete; SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete; SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = delete; - std::unordered_map> prop_fn_map_; - std::unordered_map prop_ptr_map_; + std::unordered_map> prop_ptr_map_; }; // This op name set is for setting the names of operators that should be grouped into diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 761eb47e56cb..7690ee1baabe 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -21,7 +21,6 @@ import numpy as np import unittest import ctypes -from mxnet.io import NDArrayIter from mxnet.module import Module from mxnet.symbol import Symbol from importlib import import_module @@ -31,38 +30,34 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../unittest/')) from common import with_seed -from mxnet.test_utils import assert_almost_equal +from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err import itertools OP_NAME='op_name' QUANTIZED_OP_NAME='quantized_op_name' -SG_PASS_NAME='sg_pass_name' -POST_SG_PASS_NAME='post_sg_pass_name' +SG_PASS_NAME='MKLDNN' +QUANTIZE_SG_PASS_NAME='MKLDNN_QUANTIZE' config = { 'conv': { OP_NAME: 'sg_mkldnn_conv', - QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv', - SG_PASS_NAME: 'MKLDNN', - POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE' + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv' }, 'fc': { OP_NAME: 'sg_mkldnn_fully_connected', - QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected', - SG_PASS_NAME: 'MKLDNN', - POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE' + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected' } } -DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] +DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)] def check_qsym_calibrated(qsym, out_type, 
name='conv'): - quantized_op_name = config[name][QUANTIZED_OP_NAME] + quantized_op_name = 'quantized_' + name assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1 for k, v in qsym.attr_dict().items(): if k.find('_quantize') != -1: assert v['out_type'] == out_type if k.find(quantized_op_name) != -1: - if name == 'fc' and 'enable_float_output' in v: + if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v: continue assert 'min_calib_range' in v assert 'max_calib_range' in v @@ -84,22 +79,20 @@ def check_qsym_scale_align(qsym): -def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape): - mod = Module(symbol=qsym, context=mx.current_context()) +def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape): + mod = Module(symbol=qsym, label_names=None, context=mx.current_context()) mod.bind(for_training=False, - data_shapes=[('data', data_shape)], - label_shapes=[('softmax_label', label_shape)]) + data_shapes=[('data', data_shape)]) mod.set_params(qarg_params, qaux_params) mod.forward(batch, is_train=False) for output in mod.get_outputs(): output.wait_to_read() return mod.get_outputs() -def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape): - mod = Module(symbol=qsym, context=mx.current_context()) +def check_qsym_dummy_forward(qsym, batch, data_shape): + mod = Module(symbol=qsym, label_names=None, context=mx.current_context()) mod.bind(for_training=False, - data_shapes=[('data', data_shape)], - label_shapes=[('softmax_label', label_shape)]) + data_shapes=[('data', data_shape)]) mod.init_params(initializer=mx.init.Xavier(magnitude=2.)) mod.forward(batch, is_train=False) for output in mod.get_outputs(): @@ -121,30 +114,34 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape): data = mx.random.uniform(-1.0, 1.0, shape=data_shape) net(data) +class CalibIter(mx.io.DataIter): + def __init__(self, batch, data_shape, batch_size): + super(CalibIter, self).__init__(batch_size) + self.data_shape = data_shape + self.label_shape = (batch_size,) + self.provide_data = [('data', self.data_shape)] + self.provide_label = [] + self.batch = batch + + def __iter__(self): + yield self.batch + + def check_quantize(sym, data_shape, out_type, name='conv', check_calibration=True, gluon_forward=False, check_scale_align=False): - sg_pass_name = config[name][SG_PASS_NAME] - post_sg_pass_name = config[name][POST_SG_PASS_NAME] - - fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc_softmax') - if gluon_forward == True: - sym = fc - sym_sg = sym.get_backend_symbol(sg_pass_name) - mod = Module(symbol=sym, label_names=[]) - mod.bind(for_training=False, + if name in config: + name = config[name][OP_NAME] + sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) + mod = Module(symbol=sym, label_names=None) + mod.bind(for_training=False, data_shapes=[('data', data_shape)]) - else: - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - sym_sg = sym.get_backend_symbol(sg_pass_name) - label_shape = (data_shape[0], 10) - mod = Module(symbol=sym) - mod.bind(for_training=False, - data_shapes=[('data', data_shape)], - label_shapes=[('softmax_label', label_shape)]) mod.init_params(mx.init.Normal(0.5)) arg_params, aux_params = mod.get_params() - data = [mx.random.uniform(-1, 1, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes] + if out_type == 'uint8': + data = [mx.random.uniform(0.0, 1.0, shape=shape, ctx=mx.current_context()) for _, shape in 
mod.data_shapes] + else: + data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes] batch = mx.io.DataBatch(data, []) mod.forward(batch, is_train=False) @@ -157,10 +154,8 @@ def check_quantize(sym, data_shape, out_type, name='conv', excluded_sym_names += ['sg_mkldnn_fully_connected_0'] excluded_sym_names += ['fc_softmax'] - calib_data = mx.nd.random.uniform(shape=data_shape) - calib_data = NDArrayIter(data=calib_data) - calib_data = DummyIter(calib_data) - calib_layer = lambda name: name.endswith('_output') + calib_data = CalibIter(batch, data_shape, 1) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, arg_params=arg_params, aux_params=aux_params, @@ -169,9 +164,10 @@ def check_quantize(sym, data_shape, out_type, name='conv', quantized_dtype=out_type, calib_mode='naive', calib_data=calib_data, - calib_layer=calib_layer, - num_calib_examples=5) - qsym = qsym.get_backend_symbol(post_sg_pass_name) + calib_layer=None, + label_names=None, + num_calib_examples=1) + qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) if check_calibration: check_qsym_calibrated(qsym, out_type, name=name) if check_scale_align: @@ -179,10 +175,13 @@ def check_quantize(sym, data_shape, out_type, name='conv', if gluon_forward == True: check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape) else: - check_qsym_dummy_forward(qsym, batch, data_shape, label_shape) - quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) + quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape) for i in range(len(ref_out)): - assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1) + min_range = mx.nd.min(ref_out[i]).asscalar() + max_range = mx.nd.max(ref_out[i]).asscalar() + atol = 0.1 * max(abs(min_range), abs(max_range)) + assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2) + check_qsym_dummy_forward(qsym, batch, data_shape) @with_seed() def check_quantize_whole_model_with_forward(): @@ -203,8 +202,8 @@ def check_quantize_whole_model(out_type): data = mx.sym.Variable('data') conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0') sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1') - sym_sg = sym.get_backend_symbol('MKLDNN') - mod = Module(symbol=sym, label_names=[]) + sym_sg = sym.get_backend_symbol('MKLDNN_QUANTIZE') + mod = Module(symbol=sym, label_names=None) mod.bind(for_training=False, data_shapes=[('data', data_shape)]) @@ -214,7 +213,7 @@ def check_quantize_whole_model(out_type): excluded_sym_names = [] calib_data = mx.nd.random.uniform(shape=data_shape) - calib_data = NDArrayIter(data=calib_data) + calib_data = mx.io.NDArrayIter(data=calib_data) calib_data = DummyIter(calib_data) calib_layer = lambda name: name.endswith('_output') qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, @@ -226,56 +225,61 @@ def check_quantize_whole_model(out_type): calib_mode='naive', calib_data=calib_data, calib_layer=calib_layer, - num_calib_examples=5) - qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') + label_names=None, + num_calib_examples=1) + qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') check_qsym_forward(qsym, qarg_params, qaux_params, data_shape) for qdtype in ['uint8', 'int8', 'auto']: check_quantize_whole_model(qdtype) @with_seed() -def check_fusion(sym, data_shape, attrs_op, name='conv', 
check_quantization=True): - op_name = config[name][OP_NAME] - sg_pass_name = config[name][SG_PASS_NAME] - - sym_sg = sym.get_backend_symbol(sg_pass_name) - assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 - for k, v in sym_sg.attr_dict().items(): - if k.find(op_name) != -1: - for attr_op in attrs_op: - assert v[attr_op] in ['true', 'True'] - - arg_shapes, _, aux_shapes = sym.infer_shape() - arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes] - aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] - exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') - exe.forward() - os.environ['MXNET_SUBGRAPH_BACKEND'] = sg_pass_name - exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') - exe_sg.forward() - del os.environ['MXNET_SUBGRAPH_BACKEND'] - for i in range(len(exe.outputs)): - assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) - - # fp32 to int8 - out_type_list = ['uint8', 'int8', 'auto'] +def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True, out_types=['uint8', 'int8', 'auto']): + if check_fp32_fusion: + sym_sg = sym.get_backend_symbol(SG_PASS_NAME) + for name, attrs in attrs_dict.items(): + if name in config: + op_name = config[name][OP_NAME] + else: + op_name = name + assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 + if len(attrs): + found = False + for k, v in sym_sg.attr_dict().items(): + if k.find(op_name) != -1: + found = True + for attr_name, attr_value in attrs.items(): + assert v[attr_name].lower() == attr_value.lower() + assert found + + arg_shapes, _, aux_shapes = sym.infer_shape() + arg_array = [mx.nd.random.uniform(-1.0, 1.0, shape=shape) for shape in arg_shapes] + aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] + exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe.forward() + os.environ['MXNET_SUBGRAPH_BACKEND'] = SG_PASS_NAME + exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe_sg.forward() + del os.environ['MXNET_SUBGRAPH_BACKEND'] + for i in range(len(exe.outputs)): + assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-1) if check_quantization: - for out_type in out_type_list: - check_quantize(sym, data_shape, out_type, name=name) + # fp32 to int8 + for out_type in out_types: + check_quantize(sym, data_shape, out_type, name=op_name) # TODO(ciyong), since quantized fc save its params in int8, while gluon treat the default # variable from symbol file as fp32 which results in mismatch dtype of params. # Skip quantized fc in gluon pass. 
if name != 'fc': - check_quantize(sym, data_shape, out_type, name=name, gluon_forward=True) + check_quantize(sym, data_shape, out_type, name=op_name, gluon_forward=True) def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10), name='conv'): op_name = config[name][OP_NAME] - sg_pass_name = config[name][SG_PASS_NAME] for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): - sym_sg = sym.get_backend_symbol(sg_pass_name) + sym_sg = sym.get_backend_symbol(SG_PASS_NAME) exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') attrs_dict = sym_sg.attr_dict() @@ -289,38 +293,55 @@ def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, def head_symbol(data_shape): data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') weight = mx.symbol.Variable('weight', dtype='float32') - bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') - return bn, weight + return data, weight # single conv fuision case def single_conv(no_bias, data_shape): - conv_attr = [''] + attr = {'conv': []} data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) - return conv, conv_attr + return conv, attr # conv + bn fusion case def conv_bn(no_bias, data_shape): - conv_bn_attr = ['with_bn'] + attr = {'conv': {'with_bn': 'true'}} data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") - return bn1, conv_bn_attr + return bn1, attr -# conv + relu fusion case -def conv_relu(no_bias, data_shape): - conv_relu_attr = ['with_relu'] +# conv + act fusion case +def conv_act(no_bias, data_shape, alg): + attr = {'conv': {'with_act': 'true'}} data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) - relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") - return relu, conv_relu_attr + if alg == "relu6": + relu = mx.symbol.clip(data=conv, name='relu6', a_min=0, a_max=6) + else: + relu = mx.symbol.Activation(data=conv, name='relu', act_type=alg) + return relu, attr + +# conv + act + sum fusion case +def conv_act_sum(no_bias, data_shape, alg): + attr = {'conv': {'with_act': 'true', 'with_sum': 'true'}} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + if alg == "relu6": + relu = mx.symbol.clip(data=conv, name='relu6', a_min=0, a_max=6) + else: + relu = mx.symbol.Activation(data=conv, name='relu', act_type=alg) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + sum = relu + conv1 + return sum, attr # conv + add fusion case def conv_add(no_bias, data_shape): - conv_add_attr = ['with_sum'] + attr = {'conv': {'with_sum': 'true'}} data, weight = head_symbol(data_shape) conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) @@ -328,11 +349,11 @@ def conv_add(no_bias, data_shape): kernel=(3, 3), stride=(1, 1)) pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') sum = conv1 + pool - return sum, conv_add_attr + return sum, attr # conv + 
add fusion case 2 def conv_add2(no_bias, data_shape): - conv_add_attr = ['with_sum'] + attr = {'conv': {'with_sum': 'true'}} data, weight = head_symbol(data_shape) conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) @@ -340,21 +361,24 @@ def conv_add2(no_bias, data_shape): kernel=(3, 3), stride=(1, 1)) pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') sum = pool + conv1 - return sum, conv_add_attr + return sum, attr -# conv + bn + relu fusion case -def conv_bn_relu(no_bias, data_shape): - conv_bn_relu_attr = ['with_bn', 'with_relu'] +# conv + bn + act fusion case +def conv_bn_act(no_bias, data_shape, alg): + attr = {'conv': {'with_bn': 'true', 'with_act': 'true'}} data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") - relu = mx.symbol.Activation(data=bn1, name='relu', act_type="relu") - return relu, conv_bn_relu_attr + if alg == "relu6": + relu = mx.symbol.clip(data=bn1, name='relu6', a_min=0, a_max=6) + else: + relu = mx.symbol.Activation(data=bn1, name='relu', act_type=alg) + return relu, attr -# conv + bn + add + relu fusion case -def conv_bn_sum_relu(no_bias, data_shape): - conv_bn_add_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] +# conv + bn + add + act fusion case +def conv_bn_sum_act(no_bias, data_shape, alg): + attr = {'conv': {'with_sum': 'true', 'with_postsum_act': 'true', 'with_bn': 'true'}} data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) @@ -362,12 +386,15 @@ def conv_bn_sum_relu(no_bias, data_shape): conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1)) sum1 = bn1 + conv1 - relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu") - return relu, conv_bn_add_relu_attr + if alg == "relu6": + relu = mx.symbol.clip(data=sum1, name='relu6', a_min=0, a_max=6) + else: + relu = mx.symbol.Activation(data=sum1, name='relu', act_type=alg) + return relu, attr # single concat case def single_concat(data_shape, input_num, dim): - data, weight = head_symbol(data_shape) + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') inputs = [] for i in range(input_num): inputs.append(data) @@ -388,6 +415,22 @@ def concat_scale_align(data_shape): concat = mx.symbol.Concat(*[conv1, conv2, conv3, conv4], name="concat", dim=1) return concat + +# mobilenetv2 case +def mobilenetv2_struct(data_shape): + attr = {'sg_mkldnn_conv_bn_0' : {'with_bn': 'true'}} + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight1 = mx.symbol.Variable('conv1_weight', dtype='float32') + weight2 = mx.symbol.Variable('conv2_weight', dtype='float32') + conv1 = mx.symbol.Convolution(data=data, weight=weight1, name='conv1', num_filter=64, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn1 = mx.symbol.BatchNorm(data=conv1, name="bn1") + conv2 = mx.symbol.Convolution(data=bn1, weight=weight2, name='conv2', num_filter=64, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn2 = mx.symbol.BatchNorm(data=conv2, name="bn2") + sum = bn1 + bn2 + return sum, attr + def tail_neg_symbol(sym1, sym2): fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1') fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, 
flatten=True, name='fc2') @@ -504,7 +547,7 @@ def neg_conv_bn_relu(data_shape): syms.append(sym2) attrs.append(['with_bn']) - excluded_attrs.append(['with_relu']) + excluded_attrs.append(['with_act']) return syms, attrs, excluded_attrs # conv + bn + add + relu can't be fusion case @@ -539,7 +582,7 @@ def neg_conv_bn_add_relu(data_shape): syms.append(sym1) attrs.append([]) - excluded_attrs.append(['with_sum', 'with_postsum_relu', 'with_bn']) + excluded_attrs.append(['with_sum', 'with_postsum_act', 'with_bn']) # eg.2 conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -551,7 +594,7 @@ def neg_conv_bn_add_relu(data_shape): syms.append(sym2) attrs.append(['with_bn']) - excluded_attrs.append(['with_sum', 'with_postsum_relu']) + excluded_attrs.append(['with_sum', 'with_postsum_act']) # eg.3 conv31 = mx.symbol.Convolution(data=data, weight=weight, name='conv31', num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -563,18 +606,18 @@ def neg_conv_bn_add_relu(data_shape): syms.append(sym3) attrs.append(['with_bn', 'with_sum']) - excluded_attrs.append(['with_postsum_relu']) + excluded_attrs.append(['with_postsum_act']) return syms, attrs, excluded_attrs def single_fc(no_bias, data_shape, flatten=True): - attr = [''] + attr = {'fc': {}} data, weight = head_symbol(data_shape) fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, no_bias=no_bias, flatten=flatten) return fc, attr def fc_relu(no_bias, data_shape, flatten=True): - attr = ['with_relu'] + attr = {'fc': {'with_relu': 'true'}} data, weight = head_symbol(data_shape) fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, no_bias=no_bias, flatten=flatten) @@ -614,12 +657,18 @@ def test_pos_single_conv(): check_fusion(net, data_shape, attrs) @with_seed() -def test_pos_conv_relu(): +def test_pos_conv_act(): + act_list = {"relu": True, + "sigmoid": False, + "tanh": False, + "softrelu": False, + "relu6": True} for data_shape in DATA_SHAPE: - net, attrs = conv_relu(False, data_shape) - check_fusion(net, data_shape, attrs) - net, attrs = conv_relu(True, data_shape) - check_fusion(net, data_shape, attrs) + for (alg, quantize) in act_list.items(): + net, attrs = conv_act(False, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) + net, attrs = conv_act(True, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) @with_seed() def test_pos_conv_bn(): @@ -646,25 +695,37 @@ def test_pos_conv_add2(): check_fusion(net, data_shape, attrs) @with_seed() -def test_pos_conv_bn_relu(): +def test_pos_conv_bn_act(): + act_list = {"relu": True, + "sigmoid": False, + "tanh": False, + "softrelu": False, + "relu6": True} for data_shape in DATA_SHAPE: - net, attrs = conv_bn_relu(False, data_shape) - check_fusion(net, data_shape, attrs) - net, attrs = conv_bn_relu(True, data_shape) - check_fusion(net, data_shape, attrs) + for (alg, quantize) in act_list.items(): + net, attrs = conv_bn_act(False, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) + net, attrs = conv_bn_act(True, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) @with_seed() -def test_pos_conv_bn_sum_relu(): +def test_pos_conv_bn_sum_act(): + act_list = {"relu": True, + "sigmoid": False, + "tanh": False, + "softrelu": False, + "relu6": False} for data_shape in DATA_SHAPE: - net, attrs = conv_bn_sum_relu(False, data_shape) - check_fusion(net, data_shape, attrs) - 
net, attrs = conv_bn_sum_relu(True, data_shape) - check_fusion(net, data_shape, attrs) + for (alg, quantize) in act_list.items(): + net, attrs = conv_bn_sum_act(False, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) + net, attrs = conv_bn_sum_act(True, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) @with_seed() def test_pos_single_concat(): for data_shape in DATA_SHAPE: - for out_type in ('uint8', 'int8', 'auto'): + for out_type in ('int8', 'auto'): net = single_concat(data_shape, 2, 1) check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) @@ -678,11 +739,17 @@ def test_pos_single_concat(): @with_seed() def test_pos_concat_scale_align(): for data_shape in DATA_SHAPE: - for out_type in ('uint8', 'int8', 'auto'): + for out_type in ('int8', 'auto'): net = concat_scale_align(data_shape) check_quantize(net, data_shape, out_type, check_calibration=True, check_scale_align=True) check_quantize(net, data_shape, out_type, check_calibration=True, check_scale_align=True, gluon_forward=True) +@with_seed() +def test_mobilenetv2_struct(): + for data_shape in DATA_SHAPE: + net, attrs = mobilenetv2_struct(data_shape) + check_fusion(net, data_shape, attrs, out_types=['int8', 'auto']) + @with_seed() def test_neg_conv_bn(): for data_shape in DATA_SHAPE: @@ -718,9 +785,9 @@ def test_single_fc(): for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): syms, attrs = single_fc(no_bias, dshape, flatten) if flatten is True: - check_fusion(syms, dshape, attrs, name='fc', check_quantization=True) + check_fusion(syms, dshape, attrs, check_quantization=True) else: - check_fusion(syms, dshape, attrs, name='fc', check_quantization=False) + check_fusion(syms, dshape, attrs, check_quantization=False) @with_seed() @@ -728,9 +795,9 @@ def test_fc_relu(): for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): syms, attrs = fc_relu(no_bias, dshape, flatten) if flatten is True: - check_fusion(syms, dshape, attrs, name='fc', check_quantization=True) + check_fusion(syms, dshape, attrs, check_quantization=True) else: - check_fusion(syms, dshape, attrs, name='fc', check_quantization=False) + check_fusion(syms, dshape, attrs, check_quantization=False) @with_seed() def test_neg_fc_relu():
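[Usage note on the assert_almost_equal_with_err helper added to python/mxnet/test_utils.py above: unlike assert_almost_equal, it tolerates a bounded fraction of out-of-tolerance elements, which is how the quantized-vs-fp32 output comparisons in tests/python/mkl/test_subgraph.py are checked (rtol=0.1, atol scaled to the reference output range, etol=0.2). A small illustrative sketch with arbitrary example values:

import numpy as np
from mxnet.test_utils import assert_almost_equal_with_err

ref = np.linspace(0.0, 1.0, num=100, dtype=np.float32)
quantized = ref.copy()
quantized[:10] += 0.5   # 10% of elements deviate well beyond rtol/atol

# Passes: the mismatch rate (0.10) is below etol=0.2, even though the
# mismatching elements individually violate rtol=0.1 / atol=0.1.
assert_almost_equal_with_err(quantized, ref, rtol=0.1, atol=0.1, etol=0.2)

# Would raise AssertionError: a 10% mismatch rate exceeds etol=0.05.
# assert_almost_equal_with_err(quantized, ref, rtol=0.1, atol=0.1, etol=0.05)
]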