diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index fcd5f3edeabe..27b420f9f243 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1932,6 +1932,7 @@ MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym, * \param quantized_dtype the quantized destination type for input data * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could * \param quantize_mode quantize mode to be used in quantize pass + * \param quantize_granularity quantize granularity, tensor-wise or channel-wise * \param out_num_calib_names return the number of nodes to be calibrated * \param out_calib_names return the node names to be calibrated */ @@ -1944,8 +1945,8 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, const char **excluded_op_names, const uint32_t num_offline, const char **offline_params, const char *quantized_dtype, const bool calib_quantize, - const char *quantize_mode, uint32_t* out_num_calib_names, - const char ***out_calib_names); + const char *quantize_mode, const char *quantize_granularity, + uint32_t* out_num_calib_names, const char ***out_calib_names); /*! * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 7c0ea77dc986..237c595ad086 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -31,6 +31,7 @@ #include #include +#include #include "./base.h" #include "./ndarray.h" @@ -344,7 +345,8 @@ using FNeedRequantize = std::function; * which can handle fp32 inputs directly. */ using FAvoidQuantizeInput = std::function; + const size_t index, + const std::string quantize_granularity)>; /*! * \brief Register a function to determine if the input of a quantized operator diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index ce22fb753ace..3ceccbf42c6e 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -85,7 +85,8 @@ def _quantize_params(qsym, params, th_dict): return quantized_params def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, - offline_params=None, quantized_dtype='int8', quantize_mode='smart'): + offline_params=None, quantized_dtype='int8', quantize_mode='smart', + quantize_granularity='tensor-wise'): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -109,6 +110,9 @@ def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, The quantized destination type for input data. quantize_mode: str The mode that quantization pass to apply. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. 
""" num_excluded_symbols = 0 @@ -147,6 +151,7 @@ def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, c_str(quantized_dtype), ctypes.c_bool(True), c_str(quantize_mode), + c_str(quantize_granularity), ctypes.byref(size), ctypes.byref(calib_str))) calib_layer = [] @@ -459,7 +464,8 @@ def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, - quantized_dtype='int8', quantize_mode='smart', logger=None): + quantized_dtype='int8', quantize_mode='smart', + quantize_granularity='tensor-wise', logger=None): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -515,6 +521,9 @@ def quantize_model(sym, arg_params, aux_params, The mode that quantization pass to apply. Support 'full' and 'smart'. 'full' means quantize all operator if possible. 'smart' means quantization pass will smartly choice which operator should be quantized. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. logger : Object A logging object for printing information during the process of quantization. @@ -544,11 +553,15 @@ def quantize_model(sym, arg_params, aux_params, if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8`, `uint8` or `auto`' % quantized_dtype) + if quantize_granularity not in ('tensor-wise', 'channel-wise'): + raise ValueError('unkonwn quantize_granularity %s received,' + ' expected `tensor-wise` or `channel-wise`.' % quantize_granularity) qsym, calib_layer = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, excluded_operators=excluded_op_names, - offline_params=list( - arg_params.keys()), - quantized_dtype=quantized_dtype, quantize_mode=quantize_mode) + offline_params=list(arg_params.keys()), + quantized_dtype=quantized_dtype, + quantize_mode=quantize_mode, + quantize_granularity=quantize_granularity) th_dict = {} if calib_mode is not None and calib_mode != 'none': if not isinstance(ctx, Context): @@ -597,7 +610,8 @@ def quantize_model_mkldnn(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, - quantized_dtype='int8', quantize_mode='smart', logger=None): + quantized_dtype='int8', quantize_mode='smart', + quantize_granularity='tensor-wise', logger=None): """User-level API for generating a fusion + quantized model from a FP32 model w/ or w/o calibration with Intel MKL-DNN. The backend quantized operators are only enabled for Linux systems. 
Please do not run @@ -628,7 +642,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params, calib_mode=calib_mode, calib_data=calib_data, num_calib_examples=num_calib_examples, quantized_dtype=quantized_dtype, quantize_mode=quantize_mode, - logger=logger) + quantize_granularity=quantize_granularity, logger=logger) qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE') @@ -636,7 +650,8 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, - calib_mode='entropy', quantized_dtype='int8', quantize_mode='full', + calib_mode='entropy', quantized_dtype='int8', + quantize_mode='full', quantize_granularity='tensor-wise', LayerOutputCollector=None, logger=None): """User-level API for generating a quantized model from a FP32 model w/o calibration and a collector for naive or entropy calibration. @@ -676,6 +691,9 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), The mode that quantization pass to apply. Support 'full' and 'smart'. 'full' means quantize all operator if possible. 'smart' means quantization pass will smartly choice which operator should be quantized. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. LayerOutputCollector : class For customize calibration method usage. logger : Object @@ -700,12 +718,16 @@ def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), if quantized_dtype not in ('int8', 'uint8', 'auto'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8`, `uint8` or `auto`' % quantized_dtype) + if quantize_granularity not in ('tensor-wise', 'channel-wise'): + raise ValueError('unknown quantize_granularity %s received,' + ' expected `tensor-wise` or `channel-wise`.' % quantize_granularity) qsym, calib_layer = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, excluded_operators=excluded_op_names, offline_params=list( arg_params.keys()), quantized_dtype=quantized_dtype, - quantize_mode=quantize_mode) + quantize_mode=quantize_mode, + quantize_granularity=quantize_granularity) th_dict = {} collector = None @@ -801,7 +823,7 @@ def calib_graph(qsym, arg_params, aux_params, collector, return qsym, qarg_params, aux_params -def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', +def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quantize_granularity='tensor-wise', exclude_layers=None, exclude_layers_match=None, exclude_operators=None, calib_data=None, data_shapes=None, calib_mode='none', num_calib_examples=None, ctx=cpu(), LayerOutputCollector=None, logger=None): @@ -821,6 +843,9 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', The mode that quantization pass to apply. Support 'full' and 'smart'. 'full' means quantize all operator if possible. 'smart' means quantization pass will smartly choice which operator should be quantized. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. 
exclude_layers : list of strings A list of strings representing the names of the symbols that users want to excluding exclude_layers_match : list of strings @@ -927,7 +952,8 @@ def __exit__(self, exc_type, exc_value, traceback): sym=symnet, arg_params=args, aux_params=auxs, ctx=ctx, excluded_sym_names=exclude_layers, excluded_op_names=exclude_operators, calib_mode=calib_mode, quantized_dtype=quantized_dtype, quantize_mode=quantize_mode, - LayerOutputCollector=LayerOutputCollector, logger=logger) + quantize_granularity=quantize_granularity, LayerOutputCollector=LayerOutputCollector, + logger=logger) if calib_mode is not None and calib_mode != 'none': if not isinstance(ctx, Context): @@ -987,7 +1013,9 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', """ warnings.warn('WARNING: This will be deprecated after MXNet 2.0, please use quantize_net_v2.') return quantize_net_v2(network=network, quantized_dtype=quantized_dtype, - quantize_mode=quantize_mode, exclude_layers=exclude_layers, + quantize_mode=quantize_mode, + quantize_granularity='tensor-wise', + exclude_layers=exclude_layers, exclude_layers_match=exclude_layers_match, exclude_operators=exclude_operators, calib_data=calib_data, data_shapes=data_shapes, diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 06d4b3e93ba0..6b923a80ff0b 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -925,6 +925,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const char *quantized_dtype, const bool calib_quantize, const char *quantize_mode, + const char *quantize_granularity, mx_uint* out_num_calib_names, const char ***out_calib_names) { nnvm::Symbol *s = new nnvm::Symbol(); @@ -946,12 +947,14 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, } std::string quantized_type(quantized_dtype); std::string quantized_mode(quantize_mode); + std::string quantized_granularity(quantize_granularity); g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["excluded_ops"] = std::make_shared(std::move(excluded_op)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); g.attrs["target_ctx"] = std::make_shared(target_dev); g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); + g.attrs["quantize_granularity"] = std::make_shared(std::move(quantized_granularity)); g = ApplyPass(std::move(g), "QuantizeGraph"); const auto& calib_nodes = g.GetAttr>("calib_nodes"); MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h index 7d64cf5a92a7..1c9396e890f3 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h @@ -43,6 +43,7 @@ struct MKLDNNFCParam: public dmlc::Parameter { bool with_eltwise; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset + dmlc::optional channel_wise_quantize; DMLC_DECLARE_PARAMETER(MKLDNNFCParam) { DMLC_DECLARE_FIELD(quantized).set_default(false) @@ -61,6 +62,9 @@ struct MKLDNNFCParam: public dmlc::Parameter { .describe("The maximum scalar value in the form of float32 obtained " "through calibration. 
If present, it will be used to by " "quantized fullyconnected op to calculate primitive scale"); + DMLC_DECLARE_FIELD(channel_wise_quantize) + .set_default(dmlc::optional()) + .describe("Whether support channel-wise-quantize for weight."); } }; @@ -68,8 +72,7 @@ struct MKLDNNFCFullParam { FullyConnectedParam default_param; MKLDNNFCParam mkldnn_param; MKLDNNPostEltwiseParam eltwise_param; - std::vector output_scales = {0.0}; - std::vector requantize_scales = {0.0}; + std::vector output_scales = {0.0f}; }; mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 8c401a879f15..85260940b772 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -37,7 +37,8 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( const NDArray &data, const NDArray &weight, const NDArray *bias, const mkldnn::memory::desc &out_md) { auto data_md = GetMemDesc(data); - auto weight_md = GetFCWeightDesc(weight); + auto weight_md = full_param.mkldnn_param.quantized ? + GetFCWeightDesc(weight, mshadow::kInt8) : GetFCWeightDesc(weight); auto engine = CpuEngine::Get()->get_engine(); auto propagation = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; @@ -52,22 +53,9 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( } attr.set_post_ops(ops); - if (full_param.mkldnn_param.quantized) { - if ((full_param.mkldnn_param.min_calib_range.has_value() && - full_param.mkldnn_param.max_calib_range.has_value()) || - full_param.mkldnn_param.enable_float_output) { - int mask = 0; - std::vector scales = {0.0}; - if (full_param.requantize_scales.size()) { - scales[0] = full_param.requantize_scales[0]; - } else if (full_param.output_scales.size()) { - scales[0] = full_param.output_scales[0]; - } else { - LOG(FATAL) << "Must specified either output_scales or requantize_scales!"; - } - - attr.set_output_scales(mask, scales); - } + if (full_param.mkldnn_param.quantized && full_param.output_scales.size()) { + int mask = (full_param.output_scales.size() == 1) ? 0 : (1 << 1); + attr.set_output_scales(mask, full_param.output_scales); } auto GetFCFwdPd = [&full_param, &attr, @@ -88,7 +76,8 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( if (bias) { if ((*bias).shape().ndim() != 1) LOG(FATAL) << "Unexpected shape for bias " << (*bias).shape(); - auto bias_md = GetMemDesc(*bias); + auto bias_md = + full_param.mkldnn_param.quantized ? 
GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias); mkldnn::inner_product_forward::desc desc(propagation, data_md, weight_md, bias_md, out_md); return GetFCFwdPd(desc); diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 01365067ce93..229793fad6a5 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -125,7 +125,8 @@ inline QuantizeType NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes, const std::unordered_set& excluded_ops, const int& dev_type, - std::unordered_map* quantized_node_map) { + std::unordered_map* quantized_node_map, + const std::string quantize_granularity) { std::unordered_map quantized_node; static auto& quantizable_map = Op::GetAttr("FQuantizable"); static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); @@ -165,6 +166,10 @@ inline QuantizeType NeedQuantize(NodePtr node, auto quantized_node = quantized_op_map[op](node->attrs); if (!quantized_node->op()) need = false; if (need) { + if ((quantize_granularity == "channel-wise") && + (node->op() == Op::Get("_sg_mkldnn_fully_connected"))) { + quantized_node->attrs.dict["channel_wise_quantize"] = "True"; + } quantized_node_map->insert(std::make_pair(node, quantized_node)); } if (quantizable_map.count(op)) { @@ -189,6 +194,7 @@ static void MarkQuantizedNodes(const Graph& src, const auto excluded_ops = src.GetAttr>("excluded_ops"); const auto quantize_mode = src.GetAttr("quantize_mode"); const auto dev_type = src.GetAttr("target_ctx"); + const auto quantize_granularity = src.GetAttr("quantize_granularity"); std::unordered_map> node_output_map; std::unordered_set must_quantize_nodes; @@ -196,7 +202,8 @@ static void MarkQuantizedNodes(const Graph& src, // Build node_output_map, must_quantize_nodes and support_quantize_nodes; DFSVisit(src.outputs, [&](const NodePtr& node) { auto quantize_type = - NeedQuantize(node, excluded_nodes, excluded_ops, dev_type, quantized_node_map); + NeedQuantize(node, excluded_nodes, excluded_ops, dev_type, + quantized_node_map, quantize_granularity); if (quantize_type == QuantizeType::kMust) { must_quantize_nodes.insert(node); } else if (quantize_type == QuantizeType::kSupport) { @@ -265,6 +272,13 @@ Graph QuantizeGraph(Graph &&src) { Op::GetAttr("FAvoidQuantizeInput"); const auto offline_params = src.GetAttr>("offline_params"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); + const auto quantize_granularity = src.GetAttr("quantize_granularity"); + const auto dev_type = src.GetAttr("target_ctx"); + + if (dev_type == Context::kGPU && quantize_granularity == "channel-wise") { + LOG(FATAL) << "`channel-wise` quantization option is not supported yet by GPU," + << " please set quantize_granularity to `tensor-wise` when quantizing model."; + } std::unordered_map quantized_node_map; MarkQuantizedNodes(src, &quantized_node_map); @@ -298,7 +312,7 @@ Graph QuantizeGraph(Graph &&src) { // e's source node and the newly created quantize op so that the quantize op can be // reused next time when the same entry is visited again. 
if (avoid_quantize_input_map.count(node->op()) && - avoid_quantize_input_map[node->op()](node->attrs, i)) { + avoid_quantize_input_map[node->op()](node->attrs, i, quantize_granularity)) { new_node->inputs.emplace_back(mirror_entry); } else if (!quantized_node_map.count(e.node)) { if (mirror_entry_map.count(e)) { @@ -349,7 +363,7 @@ Graph QuantizeGraph(Graph &&src) { uint32_t min_index = 1; uint32_t max_index = 2; if (avoid_quantize_input_map.count(node->op()) && - avoid_quantize_input_map[node->op()](node->attrs, i)) { + avoid_quantize_input_map[node->op()](node->attrs, i, quantize_granularity)) { // skip non-quantized input continue; } diff --git a/src/operator/quantization/quantized_batch_norm.cc b/src/operator/quantization/quantized_batch_norm.cc index 285c34d3a67f..91baf4303971 100644 --- a/src/operator/quantization/quantized_batch_norm.cc +++ b/src/operator/quantization/quantized_batch_norm.cc @@ -135,11 +135,9 @@ NNVM_REGISTER_OP(BatchNorm) } return node; }) -.set_attr("FAvoidQuantizeInput", [](const NodeAttrs &attrs, size_t index) { - if (index == 0) - return false; - else - return true; +.set_attr("FAvoidQuantizeInput", []( + const NodeAttrs &attrs, const size_t index, const std::string quantize_granularity) { + return (index != 0); }); } // namespace op diff --git a/src/operator/quantization/quantized_indexing_op.cc b/src/operator/quantization/quantized_indexing_op.cc index b4af3ecb704f..66f6936d79fd 100644 --- a/src/operator/quantization/quantized_indexing_op.cc +++ b/src/operator/quantization/quantized_indexing_op.cc @@ -181,11 +181,9 @@ NNVM_REGISTER_OP(Embedding) } return node; }) -.set_attr("FAvoidQuantizeInput", [](const NodeAttrs &attrs, size_t index) { - if (index == 0) - return true; - else - return false; +.set_attr("FAvoidQuantizeInput", []( + const NodeAttrs &attrs, const size_t index, const std::string quantize_granularity) { + return (index == 0); }); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/mkldnn/mkldnn_common.h b/src/operator/subgraph/mkldnn/mkldnn_common.h new file mode 100644 index 000000000000..87ddc438d846 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_common.h @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file mkldnn_common.h + * \brief Common header file for MKLDNN backend subgraph + * \author Ciyong Chen +*/ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_ +#if MXNET_USE_MKLDNN == 1 +#include + +namespace mxnet { +namespace op { + +template +static std::vector GetWeightScales(const NDArray &weight, const NDArray *bias, + const float data_scale, bool weight_channelwise_scale) { + auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + std::vector weight_scales; + const DType *weight_ptr = weight.data().dptr(); + const DType *bias_ptr = bias? bias->data().dptr() : nullptr; + const auto wshape = weight.shape(); + size_t channel = wshape[0]; + + size_t offset = wshape.ProdShape(1, wshape.ndim()); + std::vector weight_c_min(channel, MaxValue()); + std::vector weight_c_max(channel, MinValue()); + for (int c = 0; c < static_cast(channel); ++c) { + const DType *p1 = weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + if (weight_c_min[c] > p1[k]) + weight_c_min[c] = p1[k]; + if (weight_c_max[c] < p1[k]) + weight_c_max[c] = p1[k]; + } + } + + if (weight_channelwise_scale) { + weight_scales.resize(channel); + #pragma omp parallel for num_threads(nthreads) + for (int c = 0; c < static_cast(channel); ++c) { + float scale = GetQuantizeScale(mshadow::kInt8, weight_c_min[c], weight_c_max[c]); + if (bias_ptr && bias_ptr[c]) { + // avoid overflow on bias + // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value of bias + // to INT_MAX / 2. + float scale_max = + static_cast(bias_ptr[c] > 0 ? MaxValue() : MinValue()) / 2 / + bias_ptr[c] / data_scale; + scale = Min(scale, scale_max); + } + weight_scales[c] = scale; + } + } else { + DType total_min = weight_c_min[0]; + DType total_max = weight_c_max[0]; + for (size_t c = 0; c < channel; ++c) { + if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; + if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; + } + weight_scales.resize(3); + weight_scales[0] = GetQuantizeScale(mshadow::kInt8, total_min, total_max); + weight_scales[1] = total_min; + weight_scales[2] = total_max; + } + return weight_scales; +} + +static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias, + const mkldnn::memory::desc weight_md, + const mkldnn::memory::desc *bias_md, + const int num_group, float data_scale, + const std::vector &weight_scales, + const bool submit = true) { + MKLDNNStream *stream = MKLDNNStream::Get(); + const auto new_weight = NDArray(weight_md); + const auto conv_weights_memory = new_weight.GetMKLDNNData(); + mkldnn::primitive_attr weight_attr; + if (weight_scales.size()) { + const int weight_mask = (weight_scales.size()) == 1 ? 
0 : 1; + weight_attr.set_output_scales(weight_mask, weight_scales); + } + auto default_weights_memory = GetWeights(*weight, num_group); + if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData(); + const auto weight_reorder_pd = + mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(weight_reorder_pd), + {{MKLDNN_ARG_FROM, *default_weights_memory}, {MKLDNN_ARG_TO, *conv_weights_memory}}); + NDArray new_bias; + if (has_bias && data_scale) { + std::vector bias_scales(weight_scales.size()); + for (size_t c = 0; c < weight_scales.size(); ++c) { + bias_scales[c] = weight_scales[c] * data_scale; + } + new_bias = NDArray(*bias_md); + const auto conv_bias_memory = new_bias.GetMKLDNNData(); + const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1; + mkldnn::primitive_attr bias_attr; + bias_attr.set_output_scales(bias_mask, bias_scales); + auto bias_weights_memory = bias->GetMKLDNNData(); + const auto bias_reorder_pd = + mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(bias_reorder_pd), + {{MKLDNN_ARG_FROM, *bias_weights_memory}, {MKLDNN_ARG_TO, *conv_bias_memory}}); + } + if (submit) + stream->Submit(); + *weight = new_weight; + if (has_bias && data_scale) *bias = new_bias; +} + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index f3a7d2c4e914..df440222cf04 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -29,6 +29,7 @@ #include "mkldnn_conv-inl.h" #include "../../nn/mkldnn/mkldnn_act-inl.h" #include "../../tensor/matrix_op-inl.h" +#include "mkldnn_common.h" namespace mxnet { namespace op { @@ -81,101 +82,6 @@ static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { (param.full_conv_param.mkldnn_param.with_bn ? 4 : 0); } -template -static std::vector GetWeightScales(const NDArray &weight, const NDArray *bias, - const float data_scale, bool weight_channelwise_scale) { - std::vector weight_scales; - const DType *weight_ptr = weight.data().dptr(); - const DType *bias_ptr = bias? bias->data().dptr() : nullptr; - size_t channel = weight.shape()[0]; - - // TODO(Zhennan): Handle the case weight is not in dims 4. - size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; - std::vector weight_c_min(channel, MaxValue()); - std::vector weight_c_max(channel, MinValue()); - for (int c = 0; c < static_cast(channel); ++c) { - const DType *p1 = weight_ptr + c * offset; - for (size_t k = 0; k < offset; ++k) { - if (weight_c_min[c] > p1[k]) - weight_c_min[c] = p1[k]; - if (weight_c_max[c] < p1[k]) - weight_c_max[c] = p1[k]; - } - } - - if (weight_channelwise_scale) { - weight_scales.resize(channel); -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int c = 0; c < static_cast(channel); ++c) { - float scale = GetQuantizeScale(mshadow::kInt8, weight_c_min[c], weight_c_max[c]); - if (bias_ptr) { - // avoid overflow on bias - // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value of bias - // to INT_MAX / 2. - float scale_max = - static_cast(bias_ptr[c] > 0 ? 
MaxValue() : MinValue()) / 2 / - bias_ptr[c] / data_scale; - scale = Min(scale, scale_max); - } - weight_scales[c] = scale; - } - } else { - DType total_min = weight_c_min[0]; - DType total_max = weight_c_max[0]; - for (size_t c = 0; c < channel; ++c) { - if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; - if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; - } - weight_scales.resize(3); - weight_scales[0] = GetQuantizeScale(mshadow::kInt8, total_min, total_max); - weight_scales[1] = total_min; - weight_scales[2] = total_max; - } - return weight_scales; -} - -static void ConvertWeightBias2MKLDNN(const MKLDNNConvFullParam ¶m, - mkldnn::convolution_forward::primitive_desc fwd_pd, - NDArray *weight, NDArray *bias, bool has_bias, - float data_scale, const std::vector &weight_scales) { - MKLDNNStream *stream = MKLDNNStream::Get(); - const auto new_weight = NDArray(fwd_pd.weights_desc()); - const auto conv_weights_memory = new_weight.GetMKLDNNData(); - mkldnn::primitive_attr weight_attr; - if (weight_scales.size()) { - const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1; - weight_attr.set_output_scales(weight_mask, weight_scales); - } - auto default_weights_memory = GetWeights(*weight, param.conv_param.num_group); - if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData(); - const auto weight_reorder_pd = - mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr); - MKLDNNStream::Get()->RegisterPrimArgs( - mkldnn::reorder(weight_reorder_pd), - {{MKLDNN_ARG_FROM, *default_weights_memory}, {MKLDNN_ARG_TO, *conv_weights_memory}}); - NDArray new_bias; - if (has_bias && data_scale) { - std::vector bias_scales(weight_scales.size()); - for (size_t c = 0; c < weight_scales.size(); ++c) { - bias_scales[c] = weight_scales[c] * data_scale; - } - new_bias = NDArray(fwd_pd.bias_desc()); - const auto conv_bias_memory = new_bias.GetMKLDNNData(); - const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1; - mkldnn::primitive_attr bias_attr; - bias_attr.set_output_scales(bias_mask, bias_scales); - auto bias_weights_memory = bias->GetMKLDNNData(); - const auto bias_reorder_pd = - mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr); - MKLDNNStream::Get()->RegisterPrimArgs( - mkldnn::reorder(bias_reorder_pd), - {{MKLDNN_ARG_FROM, *bias_weights_memory}, {MKLDNN_ARG_TO, *conv_bias_memory}}); - } - stream->Submit(); - *weight = new_weight; - if (has_bias && data_scale) *bias = new_bias; -} - class SgMKLDNNConvOperator { public: explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) @@ -396,8 +302,13 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, fwd_.reset(new MKLDNNConvForward( full_conv_param, ctx.is_train, data, cached_weight_, has_bias ? &cached_bias_ : nullptr, output)); - ConvertWeightBias2MKLDNN(full_conv_param, fwd_->GetPd(), &cached_weight_, &cached_bias_, - has_bias, data_scale_, weight_scales_); + mkldnn::memory::desc bias_md; + if (has_bias) bias_md = fwd_->GetPd().bias_desc(); + ConvertWeightBias2MKLDNN(&cached_weight_, &cached_bias_, has_bias, + fwd_->GetPd().weights_desc(), + has_bias ? 
& bias_md : nullptr, + full_conv_param.conv_param.num_group, + data_scale_, weight_scales_); args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData(); if (has_bias) args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData(); @@ -751,7 +662,8 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { return node; } -bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, size_t index) { +bool SgMKLDNNAvoidConvQuantizeInput(const NodeAttrs &attrs, const size_t index, + const std::string quantize_granularity) { auto const ¶m = nnvm::get(attrs.parsed); std::unordered_set avoid_indice; size_t idx = 0; @@ -800,7 +712,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) }) .set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) -.set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidQuantizeInput); +.set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidConvQuantizeInput); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index 349d18f4a3a2..4d5233d3881f 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -37,6 +37,7 @@ #include "../../tensor/matrix_op-inl.h" #include "../../quantization/quantization_utils.h" #include "mkldnn_fc-inl.h" +#include "mkldnn_common.h" namespace mxnet { namespace op { @@ -44,8 +45,7 @@ namespace op { class SgMKLDNNFCOp { public: explicit SgMKLDNNFCOp(const nnvm::NodeAttrs &attrs) - : initialized_(false), - subgraph_sym_(*attrs.subgraphs[0]), + : subgraph_sym_(*attrs.subgraphs[0]), full_param_(nnvm::get(attrs.parsed)) {} void Forward(const OpContext &ctx, @@ -62,11 +62,13 @@ class SgMKLDNNFCOp { } private: - bool initialized_; + bool initialized_{false}; nnvm::Symbol subgraph_sym_; MKLDNNFCFullParam full_param_; + mkldnn_args_map_t args_; std::shared_ptr fwd_; - std::shared_ptr cached_weight_; + std::shared_ptr cached_out_mem_; + NDArray cached_weight_; NDArray cached_bias_; float cached_min_data_; float cached_max_data_; @@ -74,10 +76,34 @@ class SgMKLDNNFCOp { float cached_max_weight_; float cached_min_bias_; float cached_max_bias_; + size_t weight_ver_; + size_t bias_ver_; float cached_min_output_; float cached_max_output_; + float data_scale_{0.0f}; + std::vector weight_scales_; }; +static inline void MKLDNNFCFlattenData(const FullyConnectedParam ¶m, + NDArray *in_data) { + const mxnet::TShape ishape = in_data->shape(); + + // If the input data is a view of an MKLDNN array, we should create a new + // NDArray with reordered data. 
+ if (in_data->IsMKLDNNData() && in_data->IsView()) + *in_data = in_data->Reorder2Default(); + + auto data_ndim = ishape.ndim(); + if (data_ndim != 2) { + if (!param.flatten) { + *in_data = in_data->MKLDNNDataReshape( + Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1])); + } else { + *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim))); + } + } +} + void SgMKLDNNFCOp::Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, @@ -90,23 +116,33 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, size_t base_num_outputs = 1; size_t total_num_outputs = base_num_outputs; - float min_data = 0.0; - float max_data = 0.0; - float min_weight = 0.0; - float max_weight = 0.0; - float min_bias = 0.0; - float max_bias = 0.0; + float min_data = 0.0f; + float max_data = 0.0f; + float min_weight = 0.0f; + float max_weight = 0.0f; + float min_bias = 0.0f; + float max_bias = 0.0f; + + bool channel_wise = false; + if (mkldnn_param.channel_wise_quantize.has_value() && + mkldnn_param.channel_wise_quantize) { + channel_wise = true; + } if (mkldnn_param.quantized) { - total_num_inputs = base_num_inputs * 3; + if (channel_wise) { + total_num_inputs = base_num_inputs + 2; + } else { + total_num_inputs = base_num_inputs * 3; + min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr()[0]; + max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr()[0]; + if (has_bias) { + min_bias = in_data[base_num_inputs + quantized_fullc::kBiasMin].data().dptr()[0]; + max_bias = in_data[base_num_inputs + quantized_fullc::kBiasMax].data().dptr()[0]; + } + } min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr()[0]; max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr()[0]; - min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr()[0]; - max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr()[0]; - if (has_bias) { - min_bias = in_data[base_num_inputs + quantized_fullc::kBiasMin].data().dptr()[0]; - max_bias = in_data[base_num_inputs + quantized_fullc::kBiasMax].data().dptr()[0]; - } if (!mkldnn_param.enable_float_output) { total_num_outputs = base_num_outputs * 3; } @@ -115,84 +151,157 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, CHECK_EQ(out_data.size(), total_num_outputs); NDArray data = in_data[fullc::kData]; - NDArray weight = cached_weight_ ? 
*cached_weight_ : in_data[fullc::kWeight]; + NDArray weight = in_data[fullc::kWeight]; NDArray output = out_data[fullc::kOut]; - - mkldnn::memory::desc out_md = GetMemDesc(output); - MKLDNNFCFlattenData(default_param, out_data[fullc::kOut], &data, &out_md); + MKLDNNFCFlattenData(default_param, &data); if (initialized_ && mkldnn_param.quantized) { - if (cached_min_data_ != min_data || cached_max_data_ != max_data || - cached_min_weight_ != min_weight || cached_max_weight_ != max_weight || - (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) { - initialized_ = false; - } + if (channel_wise) { + if (cached_min_data_ != min_data || cached_max_data_ != max_data || + weight_ver_ != weight.version() || + (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) { + initialized_ = false; + } + } else { + if (cached_min_data_ != min_data || cached_max_data_ != max_data || + cached_min_weight_ != min_weight || cached_max_weight_ != max_weight || + (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) { + initialized_ = false; + } + } } if (!initialized_) { + const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); cached_min_data_ = min_data; cached_max_data_ = max_data; cached_min_weight_ = min_weight; cached_max_weight_ = max_weight; + weight_ver_ = weight.version(); + cached_weight_ = weight; if (has_bias) { - cached_bias_ = in_data[fullc::kBias]; cached_min_bias_ = min_bias; cached_max_bias_ = max_bias; + bias_ver_ = in_data[fullc::kBias].version(); + cached_bias_ = in_data[fullc::kBias]; } else { cached_bias_ = NDArray(); } + // create cached out_md + const mxnet::TShape ishape = data.shape(); + const mxnet::TShape oshape = output.shape(); + mkldnn::memory::dims out_dims(2); + if (oshape.ndim() == 2) { + out_dims[0] = static_cast(oshape[0]); + out_dims[1] = static_cast(oshape[1]); + } else { + if (!default_param.flatten) { + out_dims[0] = static_cast(oshape.ProdShape(0, oshape.ndim()-1)); + out_dims[1] = static_cast(oshape[oshape.ndim()-1]); + } else { + out_dims[0] = static_cast(static_cast(oshape[0])); + out_dims[1] = static_cast(oshape.ProdShape(1, oshape.ndim())); + } + } + mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()), + static_cast(GetDefaultFormat(2))); + cached_out_mem_ = std::make_shared(out_md, CpuEngine::Get()->get_engine()); + + bool support_channelwise_scale = false; if (mkldnn_param.quantized) { CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8); - float data_scale = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_); - float weight_scale = GetQuantizeScale(mshadow::kInt8, cached_min_weight_, cached_max_weight_); - if (has_bias) { - NDArray bias = in_data[fullc::kBias]; - float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_); - float bias_int32_rescale = data_scale * weight_scale / bias_scale; - // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value of bias - // to INT_MAX / 2. 
- float bias_max_rescale = - MaxValue() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale; - if (bias_int32_rescale > bias_max_rescale) { - // avoid overflow on bias - bias_int32_rescale = bias_max_rescale; - float weight_rescale = bias_int32_rescale * bias_scale / data_scale / weight_scale; - cached_weight_.reset(new NDArray(weight.storage_type(), weight.shape(), weight.ctx(), - true, mshadow::kInt8)); - int8_t *weight_ptr = weight.data().dptr(); - int8_t *quantized_weight_ptr = cached_weight_->data().dptr(); - size_t weight_size = weight.shape().Size(); -#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (index_t i = 0; i < static_cast(weight_size); ++i) { - quantized_weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale); + data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_); + + bool fuse_requantize = false; + // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize). + if (mkldnn_param.min_calib_range.has_value() && + mkldnn_param.max_calib_range.has_value()) { + cached_min_output_ = mkldnn_param.min_calib_range.value(); + cached_max_output_ = mkldnn_param.max_calib_range.value(); + support_channelwise_scale = true; + fuse_requantize = true; + } + if (mkldnn_param.enable_float_output) { + support_channelwise_scale = true; + } + // channel_wise support_channelwise_scale result + // True True True + // True False Error + // False True/False False + if (channel_wise && !support_channelwise_scale) { + LOG(FATAL) + << "Currently, channel-wise quantization requires fuse requantize or dequantize." + << " Please make sure the `min_calib_range` and `max_calib_range` are set when only" + << " fuse requantize (outputs of FullyConnected are collected during calibration phase)," + << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and " + << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)"; + } + support_channelwise_scale = support_channelwise_scale && channel_wise; + + if (support_channelwise_scale) { + MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { + weight_scales_ = + GetWeightScales(cached_weight_, has_bias ? &cached_bias_ : nullptr, + data_scale_, support_channelwise_scale); + }); + } else { + weight_scales_.resize(1); + weight_scales_[0] = + GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_); + if (has_bias) { + float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_); + float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale; + // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value + // of bias to INT_MAX / 2. 
+ float bias_max_rescale = + MaxValue() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale; + if (bias_int32_rescale > bias_max_rescale) { + // avoid overflow on bias + bias_int32_rescale = bias_max_rescale; + float weight_rescale = + bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0]; + int8_t *weight_ptr = weight.data().dptr(); + size_t weight_size = weight.shape().Size(); + #pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(weight_size); ++i) { + weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale); + } + weight_scales_[0] *= weight_rescale; + } + NDArray bias = in_data[fullc::kBias]; + cached_bias_ = + NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32); + int8_t *bias_ptr = bias.data().dptr(); + int32_t *quantized_bias_ptr = cached_bias_.data().dptr(); + size_t bias_size = bias.shape().Size(); + #pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(bias_size); ++i) { + quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale); } - weight_scale *= weight_rescale; - weight = *cached_weight_; - } - cached_bias_ = - NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32); - int8_t *bias_ptr = bias.data().dptr(); - int32_t *quantized_bias_ptr = cached_bias_.data().dptr(); - size_t bias_size = bias.shape().Size(); - #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (index_t i = 0; i < static_cast(bias_size); ++i) { - quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale); } } - if (mkldnn_param.enable_float_output) { - full_param_.output_scales[0] = 1.0 / data_scale / weight_scale; - full_param_.requantize_scales.resize(0); - } else if (mkldnn_param.min_calib_range.has_value() && - mkldnn_param.max_calib_range.has_value()) { - full_param_.output_scales.resize(0); - cached_min_output_ = mkldnn_param.min_calib_range.value(); - cached_max_output_ = mkldnn_param.max_calib_range.value(); - float out_scale = - GetQuantizeScale(IsOutputUint8(full_param_) ? mshadow::kUint8 : mshadow::kInt8, - cached_min_output_, cached_max_output_); - full_param_.requantize_scales[0] = out_scale / data_scale / weight_scale; + size_t num_channel = cached_weight_.shape()[0]; + if (fuse_requantize || mkldnn_param.enable_float_output) { + float tmp_scale_ = 1.0f; + if (fuse_requantize) { + tmp_scale_ = + GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_) / data_scale_; + } else { + tmp_scale_ = 1.0 / data_scale_; + } + + if (support_channelwise_scale) { + full_param_.output_scales.resize(num_channel); + #pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(num_channel); ++i) { + full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i]; + } + } else { + full_param_.output_scales.resize(1); + full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0]; + } } else { Stream *s = ctx.get_stream(); if (data.dtype() == mshadow::kInt8) { @@ -204,21 +313,48 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_weight, &max_weight); } + full_param_.output_scales.resize(0); } } - fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, weight, + fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, cached_weight_, (has_bias ? 
&cached_bias_ : nullptr), out_md)); + + // convert weight and bias to the format that MKL-DNN requires + if (!mkldnn_param.quantized || support_channelwise_scale) { + mkldnn::memory::desc bias_md; + if (has_bias) bias_md = fwd_->fwd_pd.bias_desc(); + ConvertWeightBias2MKLDNN(&cached_weight_, &cached_bias_, has_bias, + fwd_->fwd_pd.weights_desc(), + has_bias ? &bias_md : nullptr, + 1, data_scale_, weight_scales_, false); + } else { + cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc()); + auto cached_weight_mem = cached_weight_.GetMKLDNNData(); + auto def_weight_mem = weight.GetMKLDNNData(); + std::unordered_map args( + {{MKLDNN_ARG_FROM, *def_weight_mem}, + {MKLDNN_ARG_TO, *cached_weight_mem}}); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args); + } + + args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); + args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData(); + if (has_bias) + args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData(); + args_[MKLDNN_ARG_DST] = *cached_out_mem_; initialized_ = true; } - std::vector new_inputs; - if (has_bias) { - new_inputs = {data, weight, cached_bias_}; - } else { - new_inputs = {data, weight}; - } - MKLDNNFCForwardFullFeature(full_param_, ctx, fwd_.get(), new_inputs, req, out_data); + auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_desc()); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + cached_out_mem_->set_data_handle(reinterpret_cast(output.data().dptr())); + }); + args_[MKLDNN_ARG_SRC] = *data_mem; + args_[MKLDNN_ARG_DST] = *cached_out_mem_; + MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_); + MKLDNNStream::Get()->Submit(); if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) { float *min_output_ptr = out_data[quantized_fullc::kOutMin].data().dptr(); @@ -277,13 +413,20 @@ static std::vector SgMKLDNNFCListInputNames(const NodeAttrs &attrs) auto const &full_param = nnvm::get(attrs.parsed); std::vector input_names = DefaultSubgraphOpListInputs(attrs); if (full_param.mkldnn_param.quantized) { + bool channel_wise = false; + if (full_param.mkldnn_param.channel_wise_quantize.has_value() && + full_param.mkldnn_param.channel_wise_quantize) { + channel_wise = true; + } input_names.emplace_back("min_data"); input_names.emplace_back("max_data"); - input_names.emplace_back("min_weight"); - input_names.emplace_back("max_weight"); - if (!full_param.default_param.no_bias) { - input_names.emplace_back("min_bias"); - input_names.emplace_back("max_bias"); + if (!channel_wise) { + input_names.emplace_back("min_weight"); + input_names.emplace_back("max_weight"); + if (!full_param.default_param.no_bias) { + input_names.emplace_back("min_bias"); + input_names.emplace_back("max_bias"); + } } } return input_names; @@ -349,17 +492,25 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs, std::vector *out_types) { auto const &full_param = nnvm::get(attrs.parsed); if (full_param.mkldnn_param.quantized) { + bool channel_wise = false; + if (full_param.mkldnn_param.channel_wise_quantize.has_value() && + full_param.mkldnn_param.channel_wise_quantize) { + channel_wise = true; + } size_t base_num_inputs = full_param.default_param.no_bias ? 
2 : 3; - CHECK(in_types->at(0) == mshadow::kInt8 || in_types->at(0) == mshadow::kUint8) << "QuantizedFullyConnected only supports int8/uint8 input, while " << in_types->at(0) << " is given."; for (size_t i = 1; i < in_types->size(); ++i) { - if (i < base_num_inputs) { - TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8); - } else { + if (channel_wise) { TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } else { + if (i < base_num_inputs) { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8); + } else { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } } } @@ -448,15 +599,35 @@ nnvm::NodePtr SgMKLDNNFCQuantizedOp(const NodeAttrs& attrs) { return node; } +static bool SgMKLDNNAvoidFCQuantizeInput(const NodeAttrs& attrs, const size_t index_to_check, + const std::string quantize_granularity) { + auto const &full_param = nnvm::get(attrs.parsed); + std::unordered_set avoid_indexes; + if (quantize_granularity == "channel-wise") { + avoid_indexes.insert(fullc::kWeight); // weight + if (!full_param.default_param.no_bias) { + avoid_indexes.insert(fullc::kBias); // bias + } + } + + return avoid_indexes.count(index_to_check); +} + NNVM_REGISTER_OP(_sg_mkldnn_fully_connected) .describe(R"code(_sg_mkldnn_fully_connected)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { auto const &full_param = nnvm::get(attrs.parsed); auto num_inputs = full_param.default_param.no_bias ? 2 : 3; - if (full_param.mkldnn_param.quantized) - return num_inputs * 3; - else + if (full_param.mkldnn_param.quantized) { + if (full_param.mkldnn_param.channel_wise_quantize.has_value() && + full_param.mkldnn_param.channel_wise_quantize) { + return num_inputs + 2; // min_data, max_data + } else { + return num_inputs * 3; + } + } else { return num_inputs; + } }) .set_num_outputs([](const NodeAttrs& attrs) { auto const &full_param = nnvm::get(attrs.parsed); @@ -485,7 +656,8 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected) return QuantizeType::kMust; }) .set_attr("FQuantizedOp", SgMKLDNNFCQuantizedOp) -.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }); +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidFCQuantizeInput); } // namespace op } // namespace mxnet diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 4a2aedf58281..65b73e438ea6 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -131,6 +131,10 @@ def __iter__(self): def check_quantize(sym, data_shape, out_type, name='conv', check_calibration=True, gluon_forward=False, check_scale_align=False): + quantize_granularity_list = ['tensor-wise'] + if name == 'fc': + quantize_granularity_list += ['channel-wise'] + if name in config: name = config[name][OP_NAME] sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) @@ -158,33 +162,35 @@ def check_quantize(sym, data_shape, out_type, name='conv', calib_data = CalibIter(batch, data_shape, 1) - qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, - arg_params=arg_params, - aux_params=aux_params, - ctx=mx.current_context(), - excluded_sym_names=excluded_sym_names, - excluded_op_names=excluded_op_names, - quantized_dtype=out_type, - calib_mode='naive', - calib_data=calib_data, - label_names=None, - num_calib_examples=1, - quantize_mode='full') - qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) - if check_calibration: - check_qsym_calibrated(qsym, out_type, name=name) - if check_scale_align: - check_qsym_scale_align(qsym) - if 
gluon_forward == True: - check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape) - else: - quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape) - for i in range(len(ref_out)): - min_range = mx.nd.min(ref_out[i]).asscalar() - max_range = mx.nd.max(ref_out[i]).asscalar() - atol = 0.1 * max(abs(min_range), abs(max_range)) - assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2) - check_qsym_dummy_forward(qsym, batch, data_shape) + for quantize_granularity in quantize_granularity_list: + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + excluded_sym_names=excluded_sym_names, + excluded_op_names=excluded_op_names, + quantized_dtype=out_type, + calib_mode='naive', + calib_data=calib_data, + label_names=None, + num_calib_examples=1, + quantize_mode='full', + quantize_granularity=quantize_granularity) + qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) + if check_calibration: + check_qsym_calibrated(qsym, out_type, name=name) + if check_scale_align: + check_qsym_scale_align(qsym) + if gluon_forward == True: + check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape) + else: + quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape) + for i in range(len(ref_out)): + min_range = mx.nd.min(ref_out[i]).asscalar() + max_range = mx.nd.max(ref_out[i]).asscalar() + atol = 0.1 * max(abs(min_range), abs(max_range)) + assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2) + check_qsym_dummy_forward(qsym, batch, data_shape) @with_seed() def check_quantize_whole_model_with_forward(): @@ -240,7 +246,7 @@ def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quan if check_fp32_fusion: data_min = -1.0 data_max = 1.0 - if ''.join(sym.get_internals().list_outputs()).find('sqrt'): + if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1: check_quantization = False data_min = 0 @@ -274,12 +280,12 @@ def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quan if check_quantization: # fp32 to int8 for out_type in out_types: - check_quantize(sym, data_shape, out_type, name=op_name) + check_quantize(sym, data_shape, out_type, name=name) # TODO(ciyong), since quantized fc save its params in int8, while gluon treat the default # variable from symbol file as fp32 which results in mismatch dtype of params. # Skip quantized fc in gluon pass. if name != 'fc': - check_quantize(sym, data_shape, out_type, name=op_name, gluon_forward=True) + check_quantize(sym, data_shape, out_type, name=name, gluon_forward=True) def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10), name='conv'): @@ -767,7 +773,7 @@ def test_pos_conv_bn_sum_act(): "softrelu": True, "relu6": False, "leakyrelu": True, - "gelu": True} + "gelu": False} for data_shape in DATA_SHAPE: for (alg, quantize) in act_list.items(): net, attrs = conv_bn_sum_act(False, data_shape, alg) @@ -847,7 +853,6 @@ def test_single_fc(): else: check_fusion(syms, dshape, attrs, check_quantization=False) - @with_seed() def test_fc_eltwise(): for dshape, no_bias, flatten, alg in itertools.product(DATA_SHAPE,
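
Usage sketch (not part of the patch): the snippet below shows how the new quantize_granularity argument introduced by this patch can be driven through the user-level quantize_model() API, modeled on the updated flow in tests/python/mkl/test_subgraph.py. The toy FullyConnected network, its shapes, and the calibration data are illustrative assumptions; only the keyword itself and its accepted values ('tensor-wise', 'channel-wise') come from the patch. Per the changes to QuantizeGraph and the MKL-DNN FC operator, 'channel-wise' currently applies only to fused _sg_mkldnn_fully_connected weights, requires a fused requantize/dequantize (hence the calibration and second MKLDNN_QUANTIZE pass below), and is rejected on GPU contexts.

# Minimal sketch, assuming an MKL-DNN-enabled CPU build of MXNet with this patch applied.
import mxnet as mx
from mxnet import nd

batch_size, data_shape = 4, (4, 32)

# Toy FP32 graph; channel-wise granularity currently only changes how the
# fused MKL-DNN FullyConnected weight/bias are quantized (per output channel).
data = mx.sym.Variable('data')
sym = mx.sym.FullyConnected(data=data, num_hidden=16, name='fc')

arg_shapes, _, _ = sym.infer_shape(data=data_shape)
arg_params = {name: nd.random.uniform(-1.0, 1.0, shape=shape)
              for name, shape in zip(sym.list_arguments(), arg_shapes)
              if name != 'data'}

# Fuse FP32 ops for the MKL-DNN quantization pass, quantize with per-channel
# weight scales and a single-batch naive calibration, then fuse requantize.
sym_sg = sym.get_backend_symbol('MKLDNN_QUANTIZE')
calib_data = mx.io.NDArrayIter(data=nd.random.uniform(0, 1, shape=data_shape),
                               batch_size=batch_size)
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
    sym=sym_sg, arg_params=arg_params, aux_params={}, ctx=mx.cpu(),
    calib_mode='naive', calib_data=calib_data, label_names=None,
    num_calib_examples=batch_size, quantized_dtype='int8',
    quantize_mode='full',
    quantize_granularity='channel-wise')  # new keyword; default is 'tensor-wise'
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')

For Gluon models, quantize_net_v2() accepts the same quantize_granularity keyword; any value other than 'tensor-wise' or 'channel-wise' raises a ValueError in quantize_model()/quantize_graph().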