diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index d807e7f2d19d..3f644fc771a7 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -180,6 +180,7 @@ def save_params(fname, arg_params, aux_params, logger=None): sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) sym = sym.get_backend_symbol('MKLDNN') + sym = sym.get_backend_symbol('MKLDNN_FC') # get batch size batch_size = args.batch_size @@ -207,19 +208,18 @@ def save_params(fname, arg_params, aux_params, logger=None): if args.model == 'imagenet1k-resnet-152': rgb_mean = '0,0,0' rgb_std = '1,1,1' - excluded_sym_names += ['flatten0', 'fc1'] + excluded_sym_names += ['flatten0'] if exclude_first_conv: excluded_sym_names += ['conv0'] elif args.model == 'imagenet1k-inception-bn': rgb_mean = '123.68,116.779,103.939' rgb_std = '1,1,1' - excluded_sym_names += ['flatten', 'fc1'] + excluded_sym_names += ['flatten'] if exclude_first_conv: excluded_sym_names += ['conv_1'] elif args.model in ['resnet50_v1', 'resnet101_v1']: rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' - excluded_sym_names += ['resnetv10_dense0_fwd'] if exclude_first_conv: excluded_sym_names += ['resnetv10_conv0_fwd'] elif args.model == 'squeezenet1.0': @@ -232,14 +232,12 @@ def save_params(fname, arg_params, aux_params, logger=None): rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' excluded_sym_names += ['mobilenet0_flatten0_flatten0', - 'mobilenet0_dense0_fwd', 'mobilenet0_pool0_fwd'] if exclude_first_conv: excluded_sym_names += ['mobilenet0_conv0_fwd'] elif args.model == 'inceptionv3': rgb_mean = '123.68,116.779,103.939' rgb_std = '58.393, 57.12, 57.375' - excluded_sym_names += ['inception30_dense0_fwd'] if exclude_first_conv: excluded_sym_names += ['inception30_conv0_fwd'] elif args.model == 'custom': @@ -305,6 +303,7 @@ def save_params(fname, arg_params, aux_params, logger=None): % calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') + qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE') save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 611592aa4d82..aca7c58707e2 100755 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -159,6 +159,12 @@ def __call__(self, desc, arr): elif desc.endswith('max'): self._init_one(desc, arr) self._verbose_print(desc, 'max', arr) + elif desc.endswith('weight_quantize'): + self._init_quantized_weight(desc, arr) + self._verbose_print(desc, 'weight_quantize', arr) + elif desc.endswith('bias_quantize'): + self._init_quantized_bias(desc, arr) + self._verbose_print(desc, 'bias_quantize', arr) else: self._init_default(desc, arr) @@ -235,6 +241,9 @@ def _init_one(self, _, arr): def _init_bias(self, _, arr): arr[:] = 0.0 + def _init_quantized_bias(self, _, arr): + arr[:] = 0 + def _init_gamma(self, _, arr): arr[:] = 1.0 @@ -245,6 +254,10 @@ def _init_weight(self, name, arr): """Abstract method to Initialize weight.""" raise NotImplementedError("Must override it") + def _init_quantized_weight(self, _, arr): + _arr = random.randint(-127, 127, dtype='int32').asnumpy() + arr[:] = np.int8(_arr) + def _init_default(self, name, _): raise ValueError( 'Unknown initialization pattern for %s. 
' \ diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h new file mode 100644 index 000000000000..c08371489fed --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_fully_connected-inl.h + * \brief Common functions used by MKLDNN (Quantized) FullyConnected operator + * \author Ciyong Chen +*/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include "../fully_connected-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +struct MKLDNNFCParam: public dmlc::Parameter { + bool quantized; + bool enable_float_output; + bool with_relu; + dmlc::optional min_calib_range; // min float value calculated from calibration dataset + dmlc::optional max_calib_range; // max float value calculated from calibration dataset + + DMLC_DECLARE_PARAMETER(MKLDNNFCParam) { + DMLC_DECLARE_FIELD(quantized).set_default(false) + .describe("Whether it's a quantized FullyConnected operator"); + DMLC_DECLARE_FIELD(enable_float_output).set_default(false) + .describe("Whether to enable float32 output"); + DMLC_DECLARE_FIELD(with_relu).set_default(false) + .describe("Whether there's a post relu after FullyConnected operator"); + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized fullyconnected op to calculate primitive scale"); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32 obtained " + "through calibration. 
If present, it will be used to by " + "quantized fullyconnected op to calculate primitive scale"); + } +}; + +struct MKLDNNFCFullParam { + FullyConnectedParam default_param; + MKLDNNFCParam mkldnn_param; + std::vector output_scales = {0.0}; + std::vector requantize_scales = {0.0}; +}; + +mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( + const MKLDNNFCFullParam &full_param, const bool is_train, + const NDArray &data, const NDArray &weight, const NDArray *bias, + const mkldnn::memory::desc &out_md); + +class MKLDNNFullyConnectedForward { + public: + mkldnn::inner_product_forward::primitive_desc fwd_pd; + + MKLDNNFullyConnectedForward(const MKLDNNFCFullParam &full_param, const bool is_train, + const NDArray &data, const NDArray &weight, + const NDArray *bias, + const mkldnn::memory::desc &out_md) + : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {} + + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory *bias, const mkldnn::memory &output); + + const mkldnn::inner_product_forward &GetFwd() const { + return *fwd_; + } + + private: + std::shared_ptr fwd_; + std::shared_ptr data_; + std::shared_ptr weight_; + std::shared_ptr bias_; + std::shared_ptr out_; +}; + +typedef ParamOpSign MKLDNNFullyconSignature; + +MKLDNNFullyConnectedForward &GetFCFwd( + const FullyConnectedParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weight, + const NDArray *bias, const mkldnn::memory::desc &out_md); + +void MKLDNNFCFlattenData(const FullyConnectedParam ¶m, + const NDArray &out_data, + NDArray *in_data, + mkldnn::memory::desc *out_md); + +void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + +void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam ¶m, + const OpContext &ctx, + MKLDNNFullyConnectedForward *fwd, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 05ef7ebd6573..03d7e62da399 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -18,220 +18,296 @@ */ /*! + * Copyright (c) 2018 by Contributors * \file mkldnn_fully_connected.cc - * \brief - * \author Da Zheng + * \brief MKLDNN FullyConnected operator + * \author Da Zheng, Ciyong Chen */ -#include "../fully_connected-inl.h" -#include "./mkldnn_base-inl.h" - #if MXNET_USE_MKLDNN == 1 +#include "mkldnn_fully_connected-inl.h" + namespace mxnet { namespace op { -inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd( +DMLC_REGISTER_PARAMETER(MKLDNNFCParam); + +mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( + const MKLDNNFCFullParam &full_param, const bool is_train, const NDArray &data, const NDArray &weight, const NDArray *bias, - const mkldnn::memory::desc &out_md, const bool is_train) { + const mkldnn::memory::desc &out_md) { auto data_md = GetMemDesc(data); auto weight_md = GetMemDesc(weight); auto engine = CpuEngine::Get()->get_engine(); auto propagation = is_train ? 
                        mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+
+  mkldnn::primitive_attr attr;
+  mkldnn::post_ops ops;
+  if (full_param.mkldnn_param.with_relu) {
+    const float scale = 1.0f;
+    const float alpha = 0.0f;
+    const float beta = 1.0f;
+    ops.append_eltwise(scale, eltwise_relu, alpha, beta);
+  }
+  attr.set_post_ops(ops);
+
+  if (full_param.mkldnn_param.quantized) {
+    if ((full_param.mkldnn_param.min_calib_range.has_value() &&
+         full_param.mkldnn_param.max_calib_range.has_value()) ||
+        full_param.mkldnn_param.enable_float_output) {
+      int mask = 0;
+      std::vector<float> scales = {0.0};
+      if (full_param.requantize_scales.size()) {
+        scales[0] = full_param.requantize_scales[0];
+      } else if (full_param.output_scales.size()) {
+        scales[0] = full_param.output_scales[0];
+      } else {
+        LOG(FATAL) << "Must specify either output_scales or requantize_scales!";
+      }
+
+      attr.set_output_scales(mask, scales);
+      attr.set_int_output_round_mode(round_nearest);
+    }
+  }
+
+  auto GetFCFwdPd = [&full_param, &attr,
+                     &engine](const mkldnn::inner_product_forward::desc &desc) {
+    try {
+      return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
+    } catch (mkldnn::error &e) {
+      if (e.status == mkldnn_unimplemented &&
+          full_param.mkldnn_param.quantized) {
+        LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
+      } else {
+        LOG(ERROR) << e.message;
+      }
+      throw;
+    }
+  };
+
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+    mkldnn::inner_product_forward::desc desc(propagation,
         data_md, weight_md, bias_md, out_md);
-    return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+    return GetFCFwdPd(desc);
   } else {
-    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+    mkldnn::inner_product_forward::desc desc(propagation,
         data_md, weight_md, out_md);
-    return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+    return GetFCFwdPd(desc);
   }
 }

-inline static mkldnn::inner_product_backward_data::primitive_desc GetIpBwdData(
+inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
     const NDArray &data, const NDArray &weight, const NDArray &output,
-    mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+    mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
-  return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, ipFwd_pd);
+  return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, fwd_pd);
 }

-inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWeights(
+inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWeights(
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const NDArray &output, mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+    const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+    mkldnn::inner_product_backward_weights::desc desc(data_md,
         weight_md, bias_md, out_md);
     return
mkldnn::inner_product_backward_weights::primitive_desc( - ipBwdWeights_desc, engine, ipFwd_pd); + desc, engine, fwd_pd); } else { - mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md, + mkldnn::inner_product_backward_weights::desc desc(data_md, weight_md, out_md); return mkldnn::inner_product_backward_weights::primitive_desc( - ipBwdWeights_desc, engine, ipFwd_pd); + desc, engine, fwd_pd); } } -class MKLDNNFullyConnectForward { - std::shared_ptr data; - std::shared_ptr weight; - std::shared_ptr out; - std::shared_ptr bias; - std::shared_ptr ipFwd; - - public: - mkldnn::inner_product_forward::primitive_desc ipFwd_pd; - - MKLDNNFullyConnectForward(const FullyConnectedParam ¶m, bool is_train, - const NDArray &data, const NDArray &weight, - const NDArray *bias, - const mkldnn::memory::desc &output) - : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {} - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); +void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data, + const mkldnn::memory &weight, + const mkldnn::memory *bias, + const mkldnn::memory &output) { + if (this->data_ == nullptr) + this->data_ = std::shared_ptr(new mkldnn::memory( + fwd_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data_->set_data_handle(data.get_data_handle()); - if (this->weight == nullptr) - this->weight = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight->set_data_handle(weight.get_data_handle()); + if (this->weight_ == nullptr) + this->weight_ = std::shared_ptr(new mkldnn::memory( + fwd_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight_->set_data_handle(weight.get_data_handle()); + + if (this->out_ == nullptr) + this->out_ = std::shared_ptr(new mkldnn::memory( + fwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out_->set_data_handle(output.get_data_handle()); - if (this->out == nullptr) - this->out = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.dst_primitive_desc(), output.get_data_handle())); + if (bias != nullptr) { + if (this->bias_ == nullptr) + this->bias_ = std::shared_ptr(new mkldnn::memory( + fwd_pd.bias_primitive_desc(), bias->get_data_handle())); else - this->out->set_data_handle(output.get_data_handle()); - - if (bias != nullptr) { - if (this->bias == nullptr) - this->bias = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.bias_primitive_desc(), bias->get_data_handle())); - else - this->bias->set_data_handle(bias->get_data_handle()); - if (this->ipFwd == nullptr) - this->ipFwd = std::shared_ptr( - new mkldnn::inner_product_forward( - ipFwd_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->weight), - mkldnn::primitive::at(*this->bias), *this->out)); - } else if (this->ipFwd == nullptr) { - this->ipFwd = std::shared_ptr( + this->bias_->set_data_handle(bias->get_data_handle()); + + if (this->fwd_ == nullptr) + this->fwd_ = std::shared_ptr( new mkldnn::inner_product_forward( - ipFwd_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->weight), *this->out)); + fwd_pd, mkldnn::primitive::at(*this->data_), + mkldnn::primitive::at(*this->weight_), + mkldnn::primitive::at(*this->bias_), *this->out_)); + } else 
{ + if (this->fwd_ == nullptr) { + this->fwd_ = std::shared_ptr( + new mkldnn::inner_product_forward( + fwd_pd, mkldnn::primitive::at(*this->data_), + mkldnn::primitive::at(*this->weight_), *this->out_)); } } - const mkldnn::inner_product_forward &GetIpFwd() const { - return *ipFwd; - } -}; - -typedef ParamOpSign MKLDNNFullyconSignature; +} -static inline MKLDNNFullyConnectForward &GetFCFwd( - const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight, - const NDArray *bias, const mkldnn::memory::desc &output, - const bool is_train) { +MKLDNNFullyConnectedForward &GetFCFwd( + const FullyConnectedParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weight, + const NDArray *bias, const mkldnn::memory::desc &out_md) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fcFwds; + MKLDNNFullyConnectedForward, OpHash> fcFwds; #else static MX_THREAD_LOCAL std::unordered_map fcFwds; + MKLDNNFullyConnectedForward, OpHash> fcFwds; #endif - const FullyConnectedParam& param = nnvm::get(attrs.parsed); MKLDNNFullyconSignature key(param); + key.AddSign(is_train); key.AddSign(data); key.AddSign(weight); - key.AddSign(is_train); - if (bias) key.AddSign(*bias); auto it = fcFwds.find(key); if (it == fcFwds.end()) { - MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias, - output); - auto ins_ret = fcFwds.insert( - std::pair(key, fcFwd)); - CHECK(ins_ret.second); - it = ins_ret.first; + MKLDNNFCFullParam full_param; + full_param.default_param = param; + full_param.mkldnn_param.Init(std::unordered_map()); + MKLDNNFullyConnectedForward fcFwd(full_param, is_train, data, weight, bias, out_md); + it = AddToCache(&fcFwds, key, fcFwd); } return it->second; } -void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { +void MKLDNNFCFlattenData(const FullyConnectedParam ¶m, + const NDArray &out_data, + NDArray *in_data, + mkldnn::memory::desc *out_md) { + const mxnet::TShape ishape = in_data->shape(); + const mxnet::TShape oshape = out_data.shape(); + + // If the input data is a view of an MKLDNN array, we should create a new + // NDArray with reordered data. 
+ if (in_data->IsMKLDNNData() && in_data->IsView()) + *in_data = in_data->Reorder2Default(); + + if (ishape.ndim() != 2) { + if (!param.flatten) { + *in_data = in_data->MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1), + ishape[ishape.ndim()-1])); + mkldnn::memory::dims out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), + static_cast(oshape[ishape.ndim()-1])}; + *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()), + mkldnn::memory::format::any); + } else { + *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); + mkldnn::memory::dims out_dims{static_cast(oshape[0]), + static_cast(oshape.ProdShape(1, oshape.ndim()))}; + *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()), + mkldnn::memory::format::any); + } + } +} + +void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param, + const OpContext &ctx, + MKLDNNFullyConnectedForward *fwd, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); - const FullyConnectedParam& param = nnvm::get(attrs.parsed); - const mxnet::TShape& ishape = in_data[fullc::kData].shape(); - const mxnet::TShape& oshape = out_data[fullc::kOut].shape(); NDArray weight = in_data[fullc::kWeight]; NDArray data = in_data[fullc::kData]; - // If the input data is a view of an MKLDNN array, we should create a new - // NDArray with reordered data. - if (data.IsMKLDNNData() && data.IsView()) - data = in_data[fullc::kData].Reorder2Default(); - auto out_md = GetMemDesc(out_data[fullc::kOut]); - if (data.shape().ndim() != 2 && !param.flatten) { - data = data.MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1), - ishape[ishape.ndim()-1])); - mkldnn::memory::dims out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), - static_cast(oshape[ishape.ndim()-1])}; - out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), - mkldnn::memory::format::any); - } else if (data.shape().ndim() != 2) { - data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); - mkldnn::memory::dims out_dims{static_cast(oshape[0]), - static_cast(oshape.ProdShape(1, oshape.ndim()))}; - out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), - mkldnn::memory::format::any); + auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc()); + const mkldnn::memory *weight_mem; + if (ctx.is_train) { + if (weight.IsMKLDNNData()) { + weight.Reorder2DefaultAsync(); + } + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); + } else { + if (weight.IsDefaultData()) { + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); + weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); + } else { + weight_mem = weight.GetMKLDNNData(); + CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); + } } - MKLDNNFullyConnectForward &FCFwd = - GetFCFwd(attrs, data, weight, param.no_bias ? 
nullptr : &in_data[fullc::kBias], - out_md, ctx.is_train); - auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc()); auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], - FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data); - if (!param.no_bias) { + fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data); + if (!full_param.default_param.no_bias) { auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder( - FCFwd.ipFwd_pd.bias_primitive_desc()); - FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); + fwd->fwd_pd.bias_primitive_desc()); + fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); } else { - FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); + fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); } - MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd()); + MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Get()->Submit(); } +void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + MKLDNNFCFullParam full_param; + full_param.default_param = nnvm::get(attrs.parsed); + full_param.mkldnn_param.Init(std::unordered_map()); + + NDArray data = in_data[fullc::kData]; + mkldnn::memory::desc out_md = GetMemDesc(out_data[fullc::kOut]); + MKLDNNFCFlattenData(full_param.default_param, out_data[fullc::kOut], + &data, &out_md); + auto &fwd = GetFCFwd(full_param.default_param, ctx.is_train, data, + in_data[fullc::kWeight], + full_param.default_param.no_bias ? nullptr : &in_data[fullc::kBias], + out_md); + std::vector new_inputs; + if (full_param.default_param.no_bias) + new_inputs = {data, in_data[fullc::kWeight]}; + else + new_inputs = {data, in_data[fullc::kWeight], in_data[fullc::kBias]}; + MKLDNNFCForwardFullFeature(full_param, ctx, &fwd, new_inputs, req, out_data); +} + void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); const std::vector &in_grad = outputs; - const FullyConnectedParam& param = nnvm::get(attrs.parsed); + MKLDNNFCFullParam full_param; + full_param.default_param = nnvm::get(attrs.parsed); + full_param.mkldnn_param.Init(std::unordered_map()); + const FullyConnectedParam& param = full_param.default_param; const mxnet::TShape& ishape = inputs[fullc::kData + 1].shape(); const mxnet::TShape& oshape = inputs[fullc::kOut].shape(); @@ -251,13 +327,14 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, out_grad = out_grad.MKLDNNDataReshape(Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim()))); - mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, - param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train); + + mkldnn::inner_product_forward::primitive_desc fwd_pd = GetFCFwdImpl(full_param, ctx.is_train, + data, weight, param.no_bias ? 
nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad)); CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; if (req[fullc::kData]) { - mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData( - data, weight, out_grad, ipFwd_pd); + mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData( + data, weight, out_grad, fwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( ipBwdData_pd.diff_dst_primitive_desc()); auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); @@ -270,8 +347,8 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, } if (req[fullc::kWeight]) { mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd - = GetIPBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], - out_grad, ipFwd_pd); + = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], + out_grad, fwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( ipBwdWeights_pd.diff_dst_primitive_desc()); auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc new file mode 100644 index 000000000000..36def0002073 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_quantized_fully_connected.cc + * \brief MKLDNN Quantized FullyConnected operator + * \author Ciyong Chen + */ + +#if MXNET_USE_MKLDNN == 1 +#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h" +#include "../quantization_utils.h" + +namespace mxnet { +namespace op { + +namespace quantized_fc_enum { +enum QuantizedFCInputMinMax { kDataMin, kDataMax, kWeightMin, kWeightMax, kBiasMin, kBiasMax }; +enum QuantizedFCOutputs { kOut, kOutMin, kOutMax }; +} + +void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); + FullyConnectedParam param = nnvm::get(attrs.parsed); + const size_t num_inputs = param.no_bias ? 
2 : 3; + + CHECK_EQ(in_data.size(), static_cast(num_inputs * 3)); + CHECK_EQ(out_data.size(), 3U); + + NDArray data = in_data[fullc::kData]; + NDArray weight = in_data[fullc::kWeight]; + const TShape &ishape = data.shape(); + + CHECK(data.dtype() == mshadow::kUint8) + << "MKLDNNQuantizedFullyConnected Op only supports uint8 for now, but got " + << mxnet::op::type_string(data.dtype()); + + if (ishape.ndim() != 2) { + CHECK(param.flatten) + << "QuantizedFullyConnected Op only supports flatten=true when ishape.ndim()!=2 for now."; + data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); + } + + const float min_data = + in_data[num_inputs + quantized_fc_enum::kDataMin].data().dptr()[0]; + const float max_data = + in_data[num_inputs + quantized_fc_enum::kDataMax].data().dptr()[0]; + const float min_weight = + in_data[num_inputs + quantized_fc_enum::kWeightMin].data().dptr()[0]; + const float max_weight = + in_data[num_inputs + quantized_fc_enum::kWeightMax].data().dptr()[0]; + float *min_output_ptr = out_data[quantized_fc_enum::kOutMin].data().dptr(); + float *max_output_ptr = out_data[quantized_fc_enum::kOutMax].data().dptr(); + + auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range; + float data_scale = data_range / MaxAbs(min_data, max_data); + float weight_scale = kInt8Range / MaxAbs(min_weight, max_weight); + + NDArray quantized_bias; + if (!param.no_bias) { + NDArray bias = in_data[fullc::kBias]; + float min_bias = in_data[num_inputs + quantized_fc_enum::kBiasMin].data().dptr()[0]; + float max_bias = in_data[num_inputs + quantized_fc_enum::kBiasMax].data().dptr()[0]; + float bias_int32_rescale = data_scale * weight_scale * MaxAbs(min_bias, max_bias) / kInt8Range; + + quantized_bias = NDArray(bias.storage_type(), bias.shape(), + bias.ctx(), true, mshadow::kInt32); + int8_t *bias_ptr = bias.data().dptr(); + int32_t *quantized_bias_ptr = quantized_bias.data().dptr(); + size_t bias_size = bias.shape().Size(); + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (size_t i = 0; i < bias_size; ++i) { + quantized_bias_ptr[i] = bias_ptr[i] * bias_int32_rescale; + } + } + + Stream *s = ctx.get_stream(); + mxnet_op::Kernel::Launch(s, 1, + min_output_ptr, max_output_ptr, &min_data, &max_data, &min_weight, &max_weight); + + bool is_train = false; + mkldnn::memory::desc out_md = GetMemDesc(out_data[fullc::kOut]); + MKLDNNFCFlattenData(param, out_data[fullc::kOut], &data, &out_md); + auto &fwd = GetFCFwd(param, is_train, data, weight, + param.no_bias ? 
nullptr : &quantized_bias, out_md); + + auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + const mkldnn::memory *weight_mem = nullptr; + + if (weight.IsDefaultData()) { + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1); + weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); + } else { + weight_mem = weight.GetMKLDNNData(); + CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); + } + auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(), + req[fullc::kOut]); + const mkldnn::memory *bias_mem = nullptr; + if (!param.no_bias) + bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); + + fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + + CommitOutput(out_data[fullc::kOut], out_mem); + MKLDNNStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h new file mode 100644 index 000000000000..88d77c8d0cb2 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file mkldnn_quantized_ops-inl.h + * \brief Common functions used by MKLDNN Quantized FullyConnected operator + * \author Ciyong Chen + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include + +namespace mxnet { +namespace op { + +void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_ diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index 3b18e6591afc..742825c7a477 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -26,6 +26,10 @@ #include #include "quantization_utils.h" #include "../nn/fully_connected-inl.h" +#if MXNET_USE_MKLDNN == 1 +#include "../nn/mkldnn/mkldnn_fully_connected-inl.h" +#include "mkldnn/mkldnn_quantized_ops-inl.h" +#endif namespace mxnet { namespace op { @@ -38,7 +42,6 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) { const FullyConnectedParam& param = nnvm::get(attrs.parsed); - CHECK(param.flatten) << "QuantizedFullyConnectedOp only supports flatten=true for now"; using namespace mshadow; uint32_t num_inputs = param.no_bias ? 2 : 3; CHECK_EQ(in_shape->size(), num_inputs * 3); @@ -48,6 +51,10 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, << "QuantizedFullyConnectedOp input data shape must be given"; const mxnet::TShape& dshape = in_shape->at(0); mxnet::TShape wshape = Shape2(param.num_hidden, dshape.ProdShape(1, dshape.ndim())); + if (dshape.ndim() != 2) { + CHECK(param.flatten) + << "QuantizedFullyConnectedOp only supports flatten=true when ishape.ndim()!=2 for now. "; + } SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape); if (!param.no_bias) { mxnet::TShape bshape = Shape1(param.num_hidden); @@ -72,7 +79,14 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_type->size(), num_inputs * 3); CHECK_EQ(out_type->size(), 3U); - for (size_t i = 0; i < num_inputs; ++i) { +#if MXNET_USE_MKLDNN == 1 + // TODO(ciyong): currently, only uint8 fully_connected is upported, + // int8 fully_connected will be supported after mkldnn v0.18 + TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kUint8); +#else + TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); +#endif + for (size_t i = 1; i < num_inputs; ++i) { TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8); } for (size_t i = num_inputs; i < 3 * num_inputs; ++i) { @@ -90,10 +104,16 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs, DispatchMode* dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + uint32_t num_inputs = param.no_bias ? 
2 : 3; + CHECK_EQ(in_attrs->size(), num_inputs * 3); + CHECK_EQ(out_attrs->size(), 3U); + +#if MXNET_USE_MKLDNN == 1 + return MKLDNNStorageType(attrs, dev_mask, true, + dispatch_mode, in_attrs, out_attrs); +#else *dispatch_mode = DispatchMode::kFCompute; - if (dev_mask == mshadow::cpu::kDevMask) { - *dispatch_mode = DispatchMode::kFComputeEx; - } for (auto &v : *out_attrs) { v = kDefaultStorage; @@ -109,6 +129,7 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs, } } return true; +#endif } struct QuantizedSumInitKernelWithBias { @@ -137,28 +158,41 @@ struct QuantizedSumInitKernelWithBias { }; -template -void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { +void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { #if MSHADOW_USE_MKL == 1 const FullyConnectedParam& param = nnvm::get(attrs.parsed); using namespace mshadow; using namespace mxnet_op; + Stream *s = ctx.get_stream(); size_t num_inputs = param.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), num_inputs * 3); CHECK_EQ(out_data.size(), 3U); - const NDArray& data = in_data[0]; - const NDArray& weight = in_data[1]; - const NDArray& out = out_data[0]; - mxnet::TShape dshape = data.shape(); - mxnet::TShape wshape = weight.shape(); - mxnet::TShape oshape = out.shape(); - auto output_temp = out.data().dptr(); - auto weight_temp = weight.data().dptr(); - auto data_temp = data.data().dptr(); + + const mxnet::TShape &dshape = in_data[fullc::kData].shape_; + const mxnet::TShape &wshape = in_data[fullc::kWeight].shape_; + const mxnet::TShape &oshape = out_data[fullc::kOut].shape_; + + CHECK(in_data[fullc::kData].type_flag_ == mshadow::kInt8) + << "QuantizedFullyConnectedForwardCPU Op only supports int8 for now, but got " + << mxnet::op::type_string(in_data[fullc::kData].type_flag_); + + if (dshape.ndim() != 2) + CHECK(param.flatten) + << "QuantizedFullyConnectedOp only supports flatten=true when input_shape!=2 for now. 
"; + + Tensor weight = in_data[fullc::kWeight].get(s); + Tensor data = in_data[fullc::kData].get_with_shape( + Shape2(dshape[0], dshape.ProdShape(1, dshape.ndim())), s); + Tensor out = out_data[fullc::kOut].get_with_shape( + Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s); + + auto data_temp = data.dptr_; + auto weight_temp = weight.dptr_; + auto output_temp = out.dptr_; const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); const float alpha = 1.0f; const float beta = 1.0f; @@ -167,7 +201,6 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, const MKL_INT8 ob = 0; MKL_INT32 oc = 0; const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim()); - Stream *s = ctx.get_stream(); // cblas_gemm_s8u8s32 required first matrix must be uint8 // shift data from int8(from -128 to 127) to uint8 (from 0 to 255) int shift = 128; @@ -179,16 +212,23 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, shiftdata.dptr_[i] = data_temp[i] + shift; } - Kernel::Launch(s, 1, - out_data[1].data().dptr(), out_data[2].data().dptr(), - in_data[num_inputs].data().dptr(), in_data[num_inputs+1].data().dptr(), - in_data[num_inputs+2].data().dptr(), in_data[num_inputs+3].data().dptr()); + Tensor min_output = out_data[1].get(s); + Tensor max_output = out_data[2].get(s); + Tensor min_data = in_data[num_inputs].get(s); + Tensor max_data = in_data[num_inputs + 1].get(s); + Tensor min_weight = in_data[num_inputs + 2].get(s); + Tensor max_weight = in_data[num_inputs + 3].get(s); + + Kernel::Launch(s, 1, min_output.dptr_, + max_output.dptr_, min_data.dptr_, max_data.dptr_, min_weight.dptr_, max_weight.dptr_); if (!param.no_bias) { - const NDArray& bias = in_data[2]; - Kernel::Launch(s, n, out.data().dptr(), - bias.data().dptr(), out_data[1].data().dptr(), - out_data[2].data().dptr(), in_data[7].data().dptr(), - in_data[8].data().dptr()); + Tensor bias = in_data[fullc::kBias].get_with_shape( + Shape1(wshape[0]), s); + Tensor min_bias = in_data[num_inputs + 4].get(s); + Tensor max_bias = in_data[num_inputs + 5].get(s); + + Kernel::Launch(s, n, out.dptr_, + bias.dptr_, min_output.dptr_, max_output.dptr_, min_bias.dptr_, max_bias.dptr_); } else { #pragma omp parallel for num_threads(omp_threads) for (int i = 0; i < m * n; ++i) { @@ -216,11 +256,11 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, shiftdata.dptr_, k, oa, - weight.data().dptr(), + weight.dptr_, k, ob, beta, - out.data().dptr(), + out.dptr_, n, &oc); #else @@ -230,6 +270,21 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, #endif } +#if MXNET_USE_MKLDNN == 1 +void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + if (in_data[fullc::kData].dtype() == mshadow::kInt8) { + FallBackCompute(QuantizedFullyConnectedForwardCPU, attrs, ctx, in_data, req, out_data); + return; + } + + MKLDNNQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data); +} +#endif + NNVM_REGISTER_OP(_contrib_quantized_fully_connected) .describe(R"code(Fully Connected operator for input, weight and bias data type of int8, and accumulates in type int32 for the output. For each argument, two more arguments of type @@ -268,8 +323,11 @@ and max thresholds representing the threholds for quantizing the float32 output // will be reverted after the improvement of CachedOP is done. 
.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) -.set_attr("FComputeEx", - QuantizedFullyConnectedForward) +.set_attr("FCompute", QuantizedFullyConnectedForwardCPU) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", QuantizedFullyConnectedForwardExCPU) +#endif .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc new file mode 100644 index 000000000000..94e2bda1e16c --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -0,0 +1,442 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_fc.cc + * \brief MKLDNN (Quantized) FullyConnected operator based on subgraph + * \author Ciyong Chen +*/ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include +#include "../common.h" +#include "../../nn/mkldnn/mkldnn_base-inl.h" +#include "../../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h" +#include "../../quantization/quantization_utils.h" + +namespace mxnet { +namespace op { + +class SgMKLDNNFCOp { + public: + explicit SgMKLDNNFCOp(const nnvm::NodeAttrs &attrs) + : initialized_(false), + subgraph_sym_(*attrs.subgraphs[0]), + full_param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn fully connected only supports " + "inference computation."; + } + + private: + bool initialized_; + nnvm::Symbol subgraph_sym_; + MKLDNNFCFullParam full_param_; + std::shared_ptr fwd_; + NDArray cached_weight_; + NDArray cached_bias_; + float cached_min_data_; + float cached_max_data_; + float cached_min_weight_; + float cached_max_weight_; + float cached_min_bias_; + float cached_max_bias_; +}; + +void SgMKLDNNFCOp::Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + auto &mkldnn_param = full_param_.mkldnn_param; + auto &default_param = full_param_.default_param; + bool has_bias = !default_param.no_bias; + size_t base_num_inputs = has_bias ? 
3 : 2; + size_t total_num_inputs = base_num_inputs; + size_t base_num_outputs = 1; + size_t total_num_outputs = base_num_outputs; + + float min_data = 0.0; + float max_data = 0.0; + float min_weight = 0.0; + float max_weight = 0.0; + float min_bias = 0.0; + float max_bias = 0.0; + float *min_output_ptr = nullptr; + float *max_output_ptr = nullptr; + + if (mkldnn_param.quantized) { + total_num_inputs = base_num_inputs * 3; + min_data = in_data[base_num_inputs].data().dptr()[0]; + max_data = in_data[base_num_inputs + 1].data().dptr()[0]; + min_weight = in_data[base_num_inputs + 2].data().dptr()[0]; + max_weight = in_data[base_num_inputs + 3].data().dptr()[0]; + if (has_bias) { + min_bias = in_data[base_num_inputs + 4].data().dptr()[0]; + max_bias = in_data[base_num_inputs + 5].data().dptr()[0]; + } + if (!mkldnn_param.enable_float_output) { + total_num_outputs = base_num_outputs * 3; + min_output_ptr = out_data[1].data().dptr(); + max_output_ptr = out_data[2].data().dptr(); + } + } + CHECK_EQ(in_data.size(), total_num_inputs); + CHECK_EQ(out_data.size(), total_num_outputs); + + NDArray data = in_data[fullc::kData]; + NDArray weight = in_data[fullc::kWeight]; + NDArray output = out_data[fullc::kOut]; + const mxnet::TShape &ishape = data.shape(); + if (mkldnn_param.quantized && ishape.ndim() != 2) { + CHECK(default_param.flatten) + << "QuantizedFullyConnected only supports flatten=true when ishape.ndim() != 2 for now."; + } + + mkldnn::memory::desc out_md = GetMemDesc(output); + MKLDNNFCFlattenData(default_param, out_data[fullc::kOut], &data, &out_md); + + if (initialized_ && mkldnn_param.quantized) { + if (cached_min_data_ != min_data || cached_max_data_ != max_data || + cached_min_weight_ != min_weight || cached_max_weight_ != max_weight || + (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) { + initialized_ = false; + } + } + + if (!initialized_) { + cached_min_data_ = min_data; + cached_max_data_ = max_data; + cached_min_weight_ = min_weight; + cached_max_weight_ = max_weight; + if (has_bias) { + cached_bias_ = in_data[fullc::kBias]; + } else { + cached_bias_ = NDArray(); + } + + if (mkldnn_param.quantized) { + CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8); + auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range; + float data_scale = data_range / MaxAbs(cached_min_data_, cached_max_data_); + float weight_scale = kInt8Range / MaxAbs(cached_min_weight_, cached_max_weight_); + float quantized_out_range = mkldnn_param.with_relu ? 
kUint8Range : kInt8Range; + + if (has_bias) { + NDArray bias = in_data[fullc::kBias]; + float bias_int32_rescale = data_scale * weight_scale * + MaxAbs(min_bias, max_bias) / kInt8Range; + + cached_bias_ = NDArray(bias.storage_type(), bias.shape(), + bias.ctx(), true, mshadow::kInt32); + int8_t *bias_ptr = bias.data().dptr(); + int32_t *quantized_bias_ptr = cached_bias_.data().dptr(); + size_t bias_size = bias.shape().Size(); + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + for (size_t i = 0; i < bias_size; ++i) { + quantized_bias_ptr[i] = bias_ptr[i] * bias_int32_rescale; + } + } + + if (mkldnn_param.enable_float_output) { + full_param_.output_scales[0] = 1.0 / data_scale / weight_scale; + full_param_.requantize_scales.resize(0); + } else if (mkldnn_param.min_calib_range.has_value() && + mkldnn_param.max_calib_range.has_value()) { + full_param_.output_scales.resize(0); + *min_output_ptr = mkldnn_param.min_calib_range.value(); + *max_output_ptr = mkldnn_param.max_calib_range.value(); + + full_param_.requantize_scales[0] = quantized_out_range / + MaxAbs(*min_output_ptr, *max_output_ptr) / data_scale / weight_scale; + } else { + Stream *s = ctx.get_stream(); + mxnet_op::Kernel::Launch(s, 1, + min_output_ptr, max_output_ptr, &min_data, &max_data, &min_weight, &max_weight); + } + } + + fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, weight, + (has_bias ? &cached_bias_ : nullptr), out_md)); + initialized_ = true; + } + std::vector new_inputs; + std::vector new_req; + if (has_bias) { + new_inputs = {data, weight, cached_bias_}; + new_req = {req[fullc::kData], req[fullc::kWeight], req[fullc::kBias]}; + } else { + new_inputs = {data, weight}; + new_req = {req[fullc::kData], req[fullc::kWeight]}; + } + + MKLDNNFCForwardFullFeature(full_param_, ctx, fwd_.get(), new_inputs, new_req, out_data); +} + +static void SgMKLDNNFCParamParser(nnvm::NodeAttrs *attrs) { + MKLDNNFCFullParam full_param; + try { + full_param.mkldnn_param.Init(attrs->dict); + } catch (const dmlc::ParamError &e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto &k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + auto subgraph_sym = attrs->subgraphs[0]; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &node_name = node->op()->name; + if (node_name == "FullyConnected") { + full_param.default_param = + nnvm::get(node->attrs.parsed); + } + }); + attrs->parsed = std::move(full_param); +} + +static std::vector SgMKLDNNFCListInputNames(const NodeAttrs &attrs) { + auto const &full_param = nnvm::get(attrs.parsed); + std::vector input_names = DefaultSubgraphOpListInputs(attrs); + if (full_param.mkldnn_param.quantized) { + input_names.emplace_back("min_data"); + input_names.emplace_back("max_data"); + input_names.emplace_back("min_weight"); + input_names.emplace_back("max_weight"); + if (!full_param.default_param.no_bias) { + input_names.emplace_back("min_bias"); + input_names.emplace_back("max_bias"); + } + } + return input_names; +} + +static std::vector SgMKLDNNFCListOutputNames(const NodeAttrs &attrs) { + auto const &full_param = nnvm::get(attrs.parsed); + if (full_param.mkldnn_param.quantized) { + if (full_param.mkldnn_param.enable_float_output) + return std::vector{"output"}; + else + return std::vector{"output", "min_output", 
"max_output"}; + } else { + return std::vector{"output"}; + } +} + +template +static inline void FillBaseInputOutputInfo(const FullyConnectedParam ¶m, + std::vector *base_in_attrs, + std::vector *base_out_attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + auto base_num_inputs = param.no_bias ? 2 : 3; + + base_out_attrs->push_back(out_attrs->at(0)); + for (int i = 0; i < base_num_inputs; ++i) { + base_in_attrs->push_back(in_attrs->at(i)); + } +} + +static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs, + mxnet::ShapeVector *in_shapes, + mxnet::ShapeVector *out_shapes) { + auto const &full_param = nnvm::get(attrs.parsed); + if (full_param.mkldnn_param.quantized) { + mxnet::ShapeVector base_in_shapes; + mxnet::ShapeVector base_out_shapes; + FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes, + in_shapes, out_shapes); + bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); + + for (size_t i = 0; i < in_shapes->size(); ++i) { + if (i < base_in_shapes.size()) + in_shapes->at(i) = base_in_shapes[i]; + else + SHAPE_ASSIGN_CHECK(*in_shapes, i, Shape1(1)); + } + + out_shapes->at(0) = base_out_shapes[0]; + if (!full_param.mkldnn_param.enable_float_output) { + SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1)); + SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1)); + } + return ret; + } else { + return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes); + } +} + +static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { + auto const &full_param = nnvm::get(attrs.parsed); + if (full_param.mkldnn_param.quantized) { + size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3; + + // TODO(ciyong): currently, only uint8 fully_connected is upported, + // int8 fully_connected will be supported after mkldnn v0.18 + TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kUint8); + for (size_t i = 1; i < in_types->size(); ++i) { + if (i < base_num_inputs) { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8); + } else { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } + } + + if (full_param.mkldnn_param.enable_float_output) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); + } else { + if (full_param.mkldnn_param.min_calib_range.has_value() && + full_param.mkldnn_param.max_calib_range.has_value()) { + if (full_param.mkldnn_param.with_relu) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + } + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); + } + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); + } + return true; + } else { + return DefaultSubgraphOpType(attrs, in_types, out_types); + } +} + +static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + auto const &full_param = nnvm::get(attrs.parsed); + if (full_param.mkldnn_param.quantized) { + std::vector base_in_attrs; + std::vector base_out_attrs; + FillBaseInputOutputInfo(full_param.default_param, &base_in_attrs, &base_out_attrs, + in_attrs, out_attrs); + bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + &base_in_attrs, &base_out_attrs); + + for (size_t i = 0; i < in_attrs->size(); ++i) { + if (i < base_in_attrs.size()) + in_attrs->at(i) = base_in_attrs[i]; + else + type_assign(&in_attrs->at(i), mxnet::kDefaultStorage); + } + + out_attrs->at(0) = 
base_out_attrs[0]; + if (!full_param.mkldnn_param.enable_float_output) { + type_assign(&out_attrs->at(1), mxnet::kDefaultStorage); + type_assign(&out_attrs->at(2), mxnet::kDefaultStorage); + } + return ret; + } else { + return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); + } +} + +static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs, + Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + +static void SgMKLDNNFCForward(const OpStatePtr &state_pointer, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNFCOp &op = state_pointer.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +nnvm::NodePtr SgMKLDNNFCQuantizedOp(const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_sg_mkldnn_fully_connected"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["quantized"] = "True"; + node->attrs.subgraphs.reserve(attrs.subgraphs.size()); + for (auto sub : attrs.subgraphs) { + node->attrs.subgraphs.push_back(sub); + } + node->op()->attr_parser(&(node->attrs)); + return node; +} + +NNVM_REGISTER_OP(_sg_mkldnn_fully_connected) +.describe(R"code(_sg_mkldnn_fully_connected)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + auto const &full_param = nnvm::get(attrs.parsed); + auto num_inputs = full_param.default_param.no_bias ? 2 : 3; + if (full_param.mkldnn_param.quantized) + return num_inputs * 3; + else + return num_inputs; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + auto const &full_param = nnvm::get(attrs.parsed); + return (full_param.mkldnn_param.quantized && + !full_param.mkldnn_param.enable_float_output) ? 3 : 1; +}) +.set_attr_parser(SgMKLDNNFCParamParser) +.set_attr("FListInputNames", SgMKLDNNFCListInputNames) +.set_attr("FListOutputNames", SgMKLDNNFCListOutputNames) +.set_attr("FInferShape", SgMKLDNNFCInferShape) +.set_attr("FInferType", SgMKLDNNFCInferType) +.set_attr("FInferStorageType", SgMKLDNNFCStorageType) +.set_attr("FCreateOpState", CreateSgMKLDNNFCState) +.set_attr("FStatefulComputeEx", SgMKLDNNFCForward) +.set_attr("TIsMKLDNN", true) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FMutateInputs", + DefaultSubgraphOpMutableInputs) +.set_attr("key_var_num_args", "num_args") +.set_attr("FQuantizedOp", SgMKLDNNFCQuantizedOp) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }); + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc new file mode 100644 index 000000000000..d2d176fadbb6 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fc_post_quantize_property.cc
+ * \brief Partition graph property for MKLDNN Quantized FullyConnected operator
+ * \author Ciyong Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../common.h"
+#include "../subgraph_property.h"
+#include "../../nn/fully_connected-inl.h"
+#include "../../quantization/requantize-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define QUANTIZED_FC_NAME "_sg_mkldnn_fully_connected"
+
+class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kRequantize,
+    kSuccess,
+  };
+
+ private:
+  bool disable_all;
+  bool disable_float_output;
+  SelectStatus status;
+  std::vector matched_list;
+
+ public:
+  explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all,
+                                          const bool dis_float_output)
+      : disable_all(dis_all),
+        disable_float_output(dis_float_output) {}
+
+  bool Select(const nnvm::Node &n) override {
+    if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) {
+      status = disable_all ? kSuccess : kStart;
+      matched_list.clear();
+      matched_list.push_back(&n);
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    if (status == kFail || status == kSuccess || new_node.is_variable())
+      return false;
+    // If n isn't the last matched node, then we encountered an internal
+    // branch, we should pop out the nodes behind n and stop fusion.
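+    // (e.g. the FC output also feeds a consumer other than the matched
+    // requantize node; the requantize is then dropped from the match.)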
+    if (matched_list.back() != &n) {
+      if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+          matched_list.end()) {
+        while (matched_list.back() != &n) {
+          matched_list.pop_back();
+        }
+      }
+
+      status = kSuccess;
+      return false;
+    }
+
+    switch (status) {
+      case kStart:
+        if (new_node.op() == Op::Get("_contrib_requantize")) {
+          auto const &param = nnvm::get(new_node.attrs.parsed);
+          if (param.min_calib_range.has_value() &&
+              param.max_calib_range.has_value()) {
+            matched_list.push_back(&new_node);
+            status = kRequantize;
+            return true;
+          }
+        }
+      case kRequantize:
+        if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
+          matched_list.push_back(&new_node);
+          status = kSuccess;
+          return true;
+        }
+      default:
+        status = kSuccess;
+        return false;
+    }
+  }
+
+  std::vector Filter(
+      const std::vector &candidates) override {
+    if ((status != kSuccess) || (matched_list.size() <= 1)) {
+      return std::vector(0);
+    } else {
+      std::vector ret;
+      for (auto i : matched_list) {
+        auto non_const_i = const_cast(i);
+        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+            candidates.end()) {
+          ret.push_back(non_const_i);
+        }
+      }
+      return ret;
+    }
+  }
+};
+
+class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNFCPostQuantizeProperty() {
+    disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_POST_OPT", false);
+    disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL", false);
+    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT", false);
+
+    disable_all = disable_all || disable_fuse_all;
+    if (disable_all) {
+      LOG(INFO) << "MKLDNN FullyConnected post-quantization optimization pass is disabled.";
+    } else {
+      LOG(INFO) << "Start to execute MKLDNN FullyConnected post-quantization optimization pass.";
+    }
+  }
+
+  static SubgraphPropertyPtr Create() {
+    return std::make_shared();
+  }
+
+  nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                   const int subgraph_id = 0) const override {
+    nnvm::NodePtr fc_node = nullptr;
+    nnvm::NodePtr requantize_node = nullptr;
+    nnvm::NodePtr dequantize_node = nullptr;
+
+    DFSVisit(sym.outputs, [&](const nnvm::NodePtr &node) {
+      if (node->is_variable()) return;
+      if (node->op() == Op::Get(QUANTIZED_FC_NAME)) {
+        fc_node = node;
+      } else if (node->op() == Op::Get("_contrib_requantize")) {
+        requantize_node = node;
+      } else if (node->op() == Op::Get("_contrib_dequantize")) {
+        dequantize_node = node;
+      }
+    });
+
+    CHECK_NOTNULL(fc_node);
+    CHECK_NOTNULL(requantize_node);
+    auto const &requantize_param =
+        nnvm::get(requantize_node->attrs.parsed);
+    CHECK(requantize_param.min_calib_range.has_value());
+    CHECK(requantize_param.max_calib_range.has_value());
+
+    // When only quantized_fullyconnected and requantize are fused, set min/max_calib_range;
+    // when quantized_fullyconnected + requantize + dequantize are fused, set the dequantize flag to true.
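+    //   fc + requantize              -> min_calib_range/max_calib_range copied onto the fused node
+    //   fc + requantize + dequantize -> enable_float_output set to "True"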
+    if (dequantize_node != nullptr) {
+      fc_node->attrs.dict["enable_float_output"] = "True";
+    } else {
+      fc_node->attrs.dict["min_calib_range"] =
+          std::to_string(requantize_param.min_calib_range.value());
+      fc_node->attrs.dict["max_calib_range"] =
+          std::to_string(requantize_param.max_calib_range.value());
+    }
+    fc_node->op()->attr_parser(&(fc_node->attrs));
+    return fc_node;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector =
+        std::make_shared(disable_all,
+                         disable_float_output);
+    return selector;
+  }
+
+  void ConnectSubgraphOutputs(
+      const nnvm::NodePtr n,
+      std::vector *output_entries) const override {
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      auto entry_ptr = output_entries->at(i);
+      *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+    }
+  }
+
+ private:
+  bool disable_all;
+  bool disable_fuse_all;
+  bool disable_float_output;
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_FC_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc
new file mode 100644
index 000000000000..e4fa02d4e713
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fc_property.cc
+ * \brief Partition graph property for FullyConnected operator
+ * \author Ciyong Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../common.h"
+#include "../subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNFCSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kSuccess,
+  };
+
+ private:
+  bool disable_all;
+  bool disable_fc_relu;
+  SelectStatus status;
+  std::vector matched_list;
+
+ public:
+  SgMKLDNNFCSelector(const bool dis_all, const bool dis_fc_relu)
+      : disable_all(dis_all),
+        disable_fc_relu(dis_fc_relu) {}
+
+  bool Select(const nnvm::Node &n) override {
+    if (n.op() == Op::Get("FullyConnected")) {
+      status = disable_all ? kSuccess : kStart;
+      matched_list.clear();
+      matched_list.push_back(&n);
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    if (status == kFail || status == kSuccess || new_node.is_variable())
+      return false;
+
+    // If n isn't the last matched node, then we encountered an internal
+    // branch, we should pop out the nodes behind n and stop fusion.
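+    // (see neg_fc_relu in tests/python/mkl/test_subgraph.py for the negative
+    // case where the FullyConnected output fans out to more than one consumer.)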
+ if (matched_list.back() != &n) { + if (std::find(matched_list.begin(), matched_list.end(), &n) != + matched_list.end()) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + } + + status = kSuccess; + return false; + } + + switch (status) { + case kStart: + if ((!disable_fc_relu) && + new_node.op() == Op::Get("Activation") && + new_node.attrs.dict.at("act_type") == "relu") { + matched_list.push_back(&new_node); + status = kSuccess; + return true; + } + default: + status = kSuccess; + return false; + } + } + + std::vector Filter( + const std::vector &candidates) override { + if (status == kFail) { + return std::vector(0); + } else { + std::vector ret; + for (auto i : matched_list) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != + candidates.end()) { + ret.push_back(non_const_i); + } + } + return candidates; + } + } +}; + +class SgMKLDNNFCProperty : public SubgraphProperty { + public: + SgMKLDNNFCProperty() { + disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", false); + disable_fc_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_FC_RELU", false); + + disable_all = disable_all || disable_fc_relu; + if (disable_all) { + LOG(INFO) << "MKLDNN FullyConnected optimization pass is disabled."; + } else { + LOG(INFO) << "Start to execute MKLDNN FullyConnected optimization pass."; + } + } + + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::NodePtr n = nnvm::Node::Create(); + // This op has single output, remove duplicated. + auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); + std::ostringstream node_name; + node_name << "sg_mkldnn_"; + DFSVisit(new_sym.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &sub_name = node->op()->name; + if (sub_name == "FullyConnected") { + node_name << "fully_connected_"; + } else if ((sub_name == "Activation") && + (node->attrs.dict.at("act_type") == "relu")) { + node_name << "relu_"; + n->attrs.dict["with_relu"] = "True"; + } + }); + node_name << std::to_string(subgraph_id); + n->attrs.name = node_name.str(); + n->attrs.op = Op::Get("_sg_mkldnn_fully_connected"); + CHECK(n->attrs.op); + n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + n->op()->attr_parser(&(n->attrs)); + return n; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = std::make_shared( + disable_all, disable_fc_relu); + return selector; + } + + void ConnectSubgraphOutputs( + const nnvm::NodePtr n, + std::vector *output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; + } + } + + private: + bool disable_all; + bool disable_fc_relu; +}; + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_FC, SgMKLDNNFCProperty); + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 8de854cc290d..e6fe0011af19 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -32,17 +32,40 @@ sys.path.append(os.path.join(curr_path, '../unittest/')) from common import with_seed from mxnet.test_utils import assert_almost_equal +import itertools 
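For orientation, the FC entries added to `config` below drive the same three-step flow that `check_quantize` exercises: fuse FC (+ ReLU), quantize with calibration, then fold requantize/dequantize. A minimal end-to-end sketch, assuming a pretrained `sym`/`arg_params`/`aux_params` and a `calib_data` iterator (these names are illustrative and not part of this patch):

```python
import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# 1. Fuse FullyConnected (+ ReLU) into _sg_mkldnn_fully_connected nodes.
sym_fused = sym.get_backend_symbol('MKLDNN_FC')

# 2. Quantize and calibrate; layers to keep in fp32 go into excluded_sym_names.
qsym, qarg_params, qaux_params = quantize_model(
    sym=sym_fused, arg_params=arg_params, aux_params=aux_params,
    ctx=mx.cpu(), excluded_sym_names=[], quantized_dtype='auto',
    calib_mode='naive', calib_data=calib_data, num_calib_examples=5)

# 3. Fold requantize (and an optional trailing dequantize) into the quantized FC node.
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
```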
+ +OP_NAME='op_name' +QUANTIZED_OP_NAME='quantized_op_name' +SG_PASS_NAME='sg_pass_name' +POST_SG_PASS_NAME='post_sg_pass_name' +config = { + 'conv': { + OP_NAME: 'sg_mkldnn_conv', + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv', + SG_PASS_NAME: 'MKLDNN', + POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE' + }, + 'fc': { + OP_NAME: 'sg_mkldnn_fully_connected', + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected', + SG_PASS_NAME: 'MKLDNN_FC', + POST_SG_PASS_NAME: 'MKLDNN_POST_FC_QUANTIZE' + } +} DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] -def check_qsym_calibrated(qsym, out_type): - assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 +def check_qsym_calibrated(qsym, out_type, name='conv'): + quantized_op_name = config[name][QUANTIZED_OP_NAME] + assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1 for k, v in qsym.attr_dict().items(): - if k.find('quantized_sg_mkldnn_conv') != -1: - assert 'min_calib_range' in v - assert 'max_calib_range' in v if k.find('_quantize') != -1: assert v['out_type'] == out_type + if k.find(quantized_op_name) != -1: + if name == 'fc' and 'enable_float_output' in v: + continue + assert 'min_calib_range' in v + assert 'max_calib_range' in v def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape): mod = mx.mod.Module(symbol=qsym, context=mx.current_context()) @@ -81,22 +104,27 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape): data = mx.random.uniform(-1.0, 1.0, shape=data_shape) net(data) -def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=False): - fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc') +def check_quantize(sym, data_shape, out_type, name='conv', + check_calibration=True, gluon_forward=False): + sg_pass_name = config[name][SG_PASS_NAME] + post_sg_pass_name = config[name][POST_SG_PASS_NAME] + + fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc_softmax') if gluon_forward == True: sym = fc - sym_sg = fc.get_backend_symbol("MKLDNN") + sym_sg = sym.get_backend_symbol(sg_pass_name) mod = Module(symbol=sym, label_names=[]) mod.bind(for_training=False, data_shapes=[('data', data_shape)]) else: sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - sym_sg = sym.get_backend_symbol("MKLDNN") + sym_sg = sym.get_backend_symbol(sg_pass_name) label_shape = (data_shape[0], 10) mod = Module(symbol=sym) mod.bind(for_training=False, data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) + mod.init_params(mx.init.Normal(0.5)) arg_params, aux_params = mod.get_params() @@ -108,9 +136,12 @@ def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=Fal output.wait_to_read() ref_out = mod.get_outputs() + # TODO(ciyong), exclude the second fc due to int8 fully_connected is not + # supported before mkldnn 0.18 excluded_sym_names = [] if mx.current_context() == mx.cpu(): - excluded_sym_names += ['fc'] + excluded_sym_names += ['fc_softmax'] + excluded_sym_names += ['sg_mkldnn_fully_connected_1'] calib_data = mx.nd.random.uniform(shape=data_shape) calib_data = NDArrayIter(data=calib_data) @@ -126,9 +157,9 @@ def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=Fal calib_data=calib_data, calib_layer=calib_layer, num_calib_examples=5) - qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") - if check_conv: - check_qsym_calibrated(qsym, out_type) + qsym = qsym.get_backend_symbol(post_sg_pass_name) + if check_calibration: + 
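+        # calibration should have attached min_calib_range/max_calib_range to every
+        # fused quantized op (float-output fc nodes are exempt, see check_qsym_calibrated)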
check_qsym_calibrated(qsym, out_type, name=name) + if gluon_forward == True: + check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape) + else: @@ -137,22 +168,24 @@ def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=Fal for i in range(len(ref_out)): assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1) - @with_seed() -def check_fusion(sym, data_shape, attrs_op): - sym_sg = sym.get_backend_symbol("MKLDNN") - assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1 +def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True): + op_name = config[name][OP_NAME] + sg_pass_name = config[name][SG_PASS_NAME] + + sym_sg = sym.get_backend_symbol(sg_pass_name) + assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 for k, v in sym_sg.attr_dict().items(): - if k.find('sg_mkldnn_conv') != -1: + if k.find(op_name) != -1: for attr_op in attrs_op: - assert v[attr_op] == 'true' + assert v[attr_op] in ['true', 'True'] arg_shapes, _, aux_shapes = sym.infer_shape() arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes] aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') exe.forward() - os.environ['MXNET_SUBGRAPH_BACKEND'] = 'MKLDNN' + os.environ['MXNET_SUBGRAPH_BACKEND'] = sg_pass_name exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') exe_sg.forward() del os.environ['MXNET_SUBGRAPH_BACKEND'] @@ -160,18 +193,33 @@ def check_fusion(sym, data_shape, attrs_op): assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) # fp32 to int8 - for out_type in ('uint8', 'int8', 'auto'): - check_quantize(sym, data_shape, out_type) - check_quantize(sym, data_shape, out_type, gluon_forward=True) + # TODO(ciyong), int8 fully_connected will be supported after mkldnn 0.18 + if name == 'fc': + out_type_list = ['uint8', 'auto'] + else: + out_type_list = ['uint8', 'int8', 'auto'] + + if check_quantization: + for out_type in out_type_list: + check_quantize(sym, data_shape, out_type, name=name) + # TODO(ciyong), since quantized fc saves its params in int8, while gluon treats the default + # variable from the symbol file as fp32, which results in a dtype mismatch of the params. + # Skip quantized fc in the gluon pass.
+ if name != 'fc': + check_quantize(sym, data_shape, out_type, name=name, gluon_forward=True) + +def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, + date_shape=(4,4,10,10), name='conv'): + op_name = config[name][OP_NAME] + sg_pass_name = config[name][SG_PASS_NAME] -def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)): for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): - sym_sg = sym.get_backend_symbol("MKLDNN") + sym_sg = sym.get_backend_symbol(sg_pass_name) exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') attrs_dict = sym_sg.attr_dict() for k, v in attrs_dict.items(): - if k.find('sg_mkldnn_conv') != -1: + if k.find(op_name) != -1: for attr in attrs: assert v[attr] == 'true' for exc_attr in excluded_attr: @@ -443,6 +491,45 @@ def neg_conv_bn_add_relu(data_shape): excluded_attrs.append(['with_postsum_relu']) return syms, attrs, excluded_attrs +def single_fc(no_bias, data_shape, flatten=True): + attr = [''] + data, weight = head_symbol(data_shape) + fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, + no_bias=no_bias, flatten=flatten) + return fc, attr + +def fc_relu(no_bias, data_shape, flatten=True): + attr = ['with_relu'] + data, weight = head_symbol(data_shape) + fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, + no_bias=no_bias, flatten=flatten) + relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu") + return relu, attr + +# fc + relu can't be fusion case +# eg.1 +# fc -----------> relu +# | +# | +# ---------------> [custom op] +def neg_fc_relu(no_bias, data_shape, flatten=True): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool) + fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, + no_bias=no_bias, flatten=flatten) + relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu") + sigmoid = mx.symbol.Activation(data=fc, name='sigmoid', act_type="sigmoid") + sym = tail_neg_symbol(relu, sigmoid) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + @with_seed() def test_pos_single_conv(): for data_shape in DATA_SHAPE: @@ -503,14 +590,14 @@ def test_pos_single_concat(): for data_shape in DATA_SHAPE: for out_type in ('uint8', 'int8', 'auto'): net = single_concat(data_shape, 2, 1) - check_quantize(net, data_shape, out_type, False) - check_quantize(net, data_shape, out_type, False, True) + check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) + check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) net = single_concat(data_shape, 4, 2) - check_quantize(net, data_shape, out_type, False) - check_quantize(net, data_shape, out_type, False, True) + check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) + check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) net = single_concat(data_shape, 4, 3) - check_quantize(net, data_shape, out_type, False) - check_quantize(net, data_shape, out_type, False, True) + check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) + check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) @with_seed() def test_neg_conv_bn(): @@ -542,6 +629,30 @@ def test_neg_conv_bn_add_relu(): syms, attrs, excluded_attrs = neg_conv_bn_add_relu(data_shape) 
check_neg_fusion(syms, attrs, excluded_attrs, data_shape) +@with_seed() +def test_single_fc(): + for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): + syms, attrs = single_fc(no_bias, dshape, flatten) + if flatten is True: + check_fusion(syms, dshape, attrs, name='fc', check_quantization=True) + else: + check_fusion(syms, dshape, attrs, name='fc', check_quantization=False) + + +@with_seed() +def test_fc_relu(): + for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): + syms, attrs = fc_relu(no_bias, dshape, flatten) + if flatten is True: + check_fusion(syms, dshape, attrs, name='fc', check_quantization=True) + else: + check_fusion(syms, dshape, attrs, name='fc', check_quantization=False) + +@with_seed() +def test_neg_fc_relu(): + for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): + syms, attrs, excluded_attrs = neg_fc_relu(no_bias, dshape, flatten) + check_neg_fusion(syms, attrs, excluded_attrs, dshape, name='fc') if __name__ == "__main__": import nose diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index d8c7f08d4ca5..e4cc277d0ae3 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -278,7 +278,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p @with_seed() def test_quantized_fc(): def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): - if mx.current_context().device_type != 'gpu': + if is_test_for_native_cpu(): hasMKL = False; for key in os.environ.keys(): if operator.eq(key, "BUILD_TAG"): @@ -288,31 +288,62 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): if hasMKL == False: print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library') return + elif qdtype == 'int8' and is_test_for_mkldnn(): + print('skipped testing test_quantized_fc for mkldnn cpu int8 since it is not supported yet') + return elif qdtype == 'uint8' and is_test_for_gpu(): print('skipped testing quantized_fc for gpu uint8 since it is not supported yet') return + def maxabs(a, b): + return mx.nd.maximum(mx.nd.abs(a), mx.nd.abs(b)) + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten) arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape) arg_names = fc_fp32.list_arguments() fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') + int8_range = 127.0 if qdtype == 'uint8': data_low = 0.0 data_high = 63.0 + quantized_range = 255.0 else: data_low = -63.0 data_high = 63.0 - fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high, - shape=data_shape).astype('int32') - fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high, - shape=arg_shapes[1]).astype('int32') + quantized_range = 127.0 + + data = mx.nd.random.uniform(low=data_low, high=data_high, + shape=data_shape).astype('int32') + weight = mx.nd.random.uniform(low=data_low, high=data_high, + shape=arg_shapes[1]).astype('int32') + fc_fp32_exe.arg_dict[arg_names[0]][:] = data + fc_fp32_exe.arg_dict[arg_names[1]][:] = weight + + data_min = mx.nd.min(data).astype('float32') + data_max = mx.nd.max(data).astype('float32') + weight_min = mx.nd.min(weight).astype('float32') + weight_max = 
mx.nd.max(weight).astype('float32') + data_range = maxabs(data_min, data_max) + weight_range = maxabs(weight_min, weight_max) + if not no_bias: - fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high, - shape=arg_shapes[2]).astype('int32') + bias = mx.nd.random.uniform(low=data_low, high=data_high, + shape=arg_shapes[2]).astype('int32') + bias_min = mx.nd.min(bias).astype('float32') + bias_max = mx.nd.max(bias).astype('float32') + bias_range = maxabs(bias_min, bias_max) + + bias_scale = int8_range / bias_range + data_scale = quantized_range / data_range + weight_scale = int8_range / weight_range + bias_int32_rescale = data_scale * weight_scale / bias_scale + new_bias = mx.nd.cast(bias, dtype='float32') * bias_int32_rescale + fc_fp32_exe.arg_dict[arg_names[2]][:] = new_bias.astype('int32') + output = fc_fp32_exe.forward()[0] - qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8') + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype) fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten) qarg_names = fc_int8.list_arguments() @@ -322,20 +353,19 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): fc_int8_exe = fc_int8.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null') fc_int8_exe.arg_dict[qarg_names[0]][:] = fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype) fc_int8_exe.arg_dict[qarg_names[1]][:] = fc_fp32_exe.arg_dict[arg_names[1]].astype('int8') - quantized_range = 127.0 if no_bias: - fc_int8_exe.arg_dict[qarg_names[2]][:] = -quantized_range - fc_int8_exe.arg_dict[qarg_names[3]][:] = quantized_range - fc_int8_exe.arg_dict[qarg_names[4]][:] = -quantized_range - fc_int8_exe.arg_dict[qarg_names[5]][:] = quantized_range + fc_int8_exe.arg_dict[qarg_names[2]][:] = -data_range + fc_int8_exe.arg_dict[qarg_names[3]][:] = data_range + fc_int8_exe.arg_dict[qarg_names[4]][:] = -weight_range + fc_int8_exe.arg_dict[qarg_names[5]][:] = weight_range else: - fc_int8_exe.arg_dict[qarg_names[2]][:] = fc_fp32_exe.arg_dict[arg_names[2]].astype('int8') - fc_int8_exe.arg_dict[qarg_names[3]][:] = -quantized_range - fc_int8_exe.arg_dict[qarg_names[4]][:] = quantized_range - fc_int8_exe.arg_dict[qarg_names[5]][:] = -quantized_range - fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range - fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range - fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range + fc_int8_exe.arg_dict[qarg_names[2]][:] = bias.astype('int8') + fc_int8_exe.arg_dict[qarg_names[3]][:] = -data_range + fc_int8_exe.arg_dict[qarg_names[4]][:] = data_range + fc_int8_exe.arg_dict[qarg_names[5]][:] = -weight_range + fc_int8_exe.arg_dict[qarg_names[6]][:] = weight_range + fc_int8_exe.arg_dict[qarg_names[7]][:] = -bias_range + fc_int8_exe.arg_dict[qarg_names[8]][:] = bias_range qoutput, min_range, max_range = fc_int8_exe.forward() if no_bias: