From 315c580a81d8a3a8b9fad96ba08be2072de7c3b8 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Thu, 26 Sep 2019 09:47:20 +0800
Subject: [PATCH 1/3] int8 conv quantize dequantize requantize

Change-Id: Ibd9df97288a95c61d6d85ec3831fd18b626ca283
---
 include/mxnet/ndarray.h                            |   2 +-
 src/executor/graph_executor.cc                     |   2 +-
 src/ndarray/ndarray.cc                             |  12 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h           |  15 +++
 .../nn/mkldnn/mkldnn_convolution-inl.h             |   4 +
 src/operator/quantization/dequantize.cc            |   8 +-
 .../mkldnn/mkldnn_dequantize-inl.h                 |  45 ++++----
 .../quantization/mkldnn/mkldnn_quantize-inl.h      |  44 +++-----
 .../mkldnn/mkldnn_quantize_v2-inl.h                |  43 +++-----
 .../mkldnn/mkldnn_quantized_conv.cc                |  35 +++---
 .../mkldnn/mkldnn_requantize-inl.h                 |  26 ++---
 src/operator/quantization/quantize.cc              |   6 +-
 src/operator/quantization/quantize_v2.cc           |   8 +-
 src/operator/quantization/quantized_conv.cc        |   4 +-
 src/operator/quantization/requantize.cc            |   6 +-
 .../subgraph/mkldnn/mkldnn_conv-inl.h              |   6 +-
 src/operator/subgraph/mkldnn/mkldnn_conv.cc        | 103 +++++++++---------
 .../subgraph/mkldnn/mkldnn_conv_property.h         |   2 +-
 .../mkldnn/mkldnn_post_quantize_property.h         |   2 +-
 .../mkldnn/mkldnn_subgraph_property.cc             |  22 ++--
 20 files changed, 195 insertions(+), 200 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 16bb7e40b7a2..def1684d661c 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -787,7 +787,7 @@ class NDArray {
   /*!
    * \ Fix mkldnn memory descriptor mismatch from NDArray.
    */
-  void UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format);
+  void UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc);
 #endif
   /*!
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index d92253266f35..8099b391e3ff 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -42,7 +42,7 @@ namespace exec {
 using namespace mxnet::common;
 static const std::string GetDefaultSubgraphBackend() {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   return std::string("MKLDNN");
 #else
   return std::string();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index f174d75553b0..b27586c1f0b6 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -724,16 +724,14 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) {
   return ptr_->mkl_mem_->GetRaw();
 }
-void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format) {
-  const mkldnn::memory *mem = GetMKLDNNData();
-  auto mem_desc = mem->get_desc();
+void NDArray::UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc) {
+  auto new_desc = desc;
   auto this_dtype = get_mkldnn_type(dtype());
-  mkldnn::memory::desc data_md(
-      mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
-      this_dtype, format);
-  ptr_->mkl_mem_.reset(new MKLDNNMemory(data_md, ptr_->shandle.dptr));
+  new_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(new_desc, ptr_->shandle.dptr));
   MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
 }
+
 #endif
 void NDArray::SetTBlob() const {
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 0148a6e83122..de345c8af5e4 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -220,6 +220,21 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) {
   }
 }
+template <typename T>
+static inline mkldnn::memory::data_type get_mkldnn_type() {
+  return static_cast<mkldnn::memory::data_type>(data_type_enum<T>::type);
+}
+
+static inline mkldnn_data_type_t get_mkldnn_type_t(int dtype) {
+  return static_cast<mkldnn_data_type_t>(get_mkldnn_type(dtype));
+}
+
+template <typename T>
+static inline mkldnn_data_type_t get_mkldnn_type_t() {
+  return static_cast<mkldnn_data_type_t>(data_type_enum<T>::type);
+}
+
+
 static inline int get_mxnet_type(mkldnn_data_type_t dtype) {
   auto mkldnn_dtype = static_cast<mkldnn::memory::data_type>(dtype);
   switch (mkldnn_dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
index b3ceb9595730..9c7cbb73f58b 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -99,6 +99,10 @@ class MKLDNNConvForward {
 typedef ParamOpSign<MKLDNNConvFullParam> MKLDNNConvSignature;
+MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam &param, const bool is_train,
+                              const NDArray &data, const NDArray &weight, const NDArray *bias,
+                              const NDArray &output);
+
 void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const OpContext &ctx,
                                          MKLDNNConvForward *fwd,
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index e8e2cd90b86c..3041db94dc95 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -23,7 +23,7 @@
  * \brief
 */
 #include "./dequantize-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_dequantize-inl.h"
 #endif
@@ -37,7 +37,7 @@ bool DequantizeStorageType(const nnvm::NodeAttrs& attrs,
                            std::vector<int> *in_attrs,
                            std::vector<int> *out_attrs) {
   *dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (dev_mask == mshadow::cpu::kDevMask) {
     *dispatch_mode = DispatchMode::kFComputeEx;
   }
@@ -55,7 +55,7 @@ static OpStatePtr CreateDequantizeState(const nnvm::NodeAttrs &attrs, Context ct
   if (ctx.dev_type == kGPU) {
     state = OpStatePtr::Create<DequantizeOperator<gpu>>(attrs);
   } else {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     state = OpStatePtr::Create<SgMKLDNNDequantizeOperator>(attrs);
 #else
     state = OpStatePtr::Create<DequantizeOperator<cpu>>(attrs);
 #endif
@@ -95,7 +95,7 @@ by keep zero centered for the quantized value:
 // will be reverted after the improvement of CachedOP is done.
.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FCreateOpState", CreateDequantizeState) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", SgMKLDNNDequantizeForward) #endif diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h index 27fa070afbe0..ae891dd45712 100644 --- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h @@ -25,7 +25,7 @@ #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_ #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include #include @@ -48,8 +48,8 @@ class SgMKLDNNDequantizeOperator { DequantizeParam param_; float cached_data_min_{0.f}; float cached_data_max_{0.f}; - std::shared_ptr i_mem_; - std::shared_ptr o_mem_; + mkldnn::memory::desc o_desc_; + mkldnn_args_map_t args_; std::shared_ptr fwd_pd_; }; @@ -79,37 +79,30 @@ void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type"; } float scale = real_range / quantized_range; - primitive_attr attr; + mkldnn::primitive_attr attr; const int mask = 0; std::vector scales = {scale}; attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); + auto i_desc = i_mem->get_desc(); size_t i_ndim = in_buffer.shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); - } - mkldnn::memory::format o_fmt = static_cast(i_desc.data.format); - if (o_fmt == mkldnn::memory::format::nhwc) { - // For 4d tensor, nchw is the default format - o_fmt = mkldnn::memory::format::nchw; + if (i_ndim == 4) { + mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nchw; + mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims); + o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type(), o_fmt); + } else { + o_desc_ = i_desc; + o_desc_.data.data_type = get_mkldnn_type_t(); } - auto o_desc = - mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum::type, o_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - i_mem_ = std::make_shared(i_mpd, nullptr); - o_mem_ = std::make_shared(o_mpd, nullptr); - fwd_pd_ = std::make_shared(reorder_pd, *i_mem_, *o_mem_); + auto reorder_pd = + mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr); + fwd_pd_ = std::make_shared(reorder_pd); initialized_ = true; } - auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]); - i_mem_->set_data_handle(i_mem->get_data_handle()); - o_mem_->set_data_handle(o_mem.second->get_data_handle()); - MKLDNNStream::Get()->RegisterPrim(*fwd_pd_); + auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]); + args_[MKLDNN_ARG_FROM] = *i_mem; + args_[MKLDNN_ARG_TO] = *o_mem.second; + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_); CommitOutput(outputs[0], o_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h index 
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
index 7a00f621d452..15f43e16f02e 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
@@ -25,7 +25,7 @@
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -35,11 +35,11 @@ namespace mxnet {
 namespace op {
-template<typename SrcType, typename DstType>
+template <typename SrcType, typename DstType>
 static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
                                      const std::vector<NDArray>& outputs,
                                      const QuantizeParam& param,
-                                     const std::vector<OpReqType> &req) {
+                                     const std::vector<OpReqType>& req) {
   using namespace mshadow;
   using namespace mxnet_op;
   using red::limits::MaxValue;
@@ -60,38 +60,30 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
     LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
   }
   float scale = quantized_range / real_range;
-  primitive_attr attr;
+  mkldnn::primitive_attr attr;
   const int mask = 0;
   std::vector<float> scales = {scale};
   attr.set_output_scales(mask, scales);
-  attr.set_int_output_round_mode(round_nearest);
   mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
-  NDArray in_buffer = inputs[0];
-  if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
-    in_buffer = inputs[0].Reorder2Default();
+  NDArray in_buffer = inputs[0];
+  if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
   auto i_mem = in_buffer.GetMKLDNNData();
-  auto i_mpd = i_mem->get_primitive_desc();
-  auto i_desc = i_mpd.desc();
-  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
-  if (i_fmt == mkldnn::memory::format::nchw ||
-      i_fmt == mkldnn::memory::format::nChw8c ||
-      i_fmt == mkldnn_nChw16c) {
-    i_fmt = mkldnn::memory::format::nhwc;
-  }
+  auto i_desc = i_mem->get_desc();
   size_t i_ndim = in_buffer.shape().ndim();
-  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-  for (size_t i = 0; i < i_ndim; i++) {
-    i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+  mkldnn::memory::desc o_desc;
+  if (i_ndim == 4) {
+    mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc;
+    mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
+    o_desc = mkldnn::memory::desc(o_dims, get_mkldnn_type<DstType>(), o_fmt);
+  } else {
+    o_desc = i_desc;
+    o_desc.data.data_type = get_mkldnn_type_t<DstType>();
   }
-  auto o_desc = mkldnn::memory::desc(i_dims,
-                                     (mkldnn::memory::data_type)data_type_enum<DstType>::type,
-                                     i_fmt);
-  auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
-  auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]);
+  MKLDNNStream::Get()->RegisterPrimArgs(
+      mkldnn::reorder(reorder_pd), {{MKLDNN_ARG_FROM, *i_mem}, {MKLDNN_ARG_TO, *o_mem.second}});
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
 }
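
The reorder construction above is the core API change that repeats through this patch: MKL-DNN 1.0 builds reorder::primitive_desc from engines plus memory descriptors and supplies data through an argument map at execution time, where 0.x bound memory primitives at construction and used RegisterPrim. A condensed sketch of the 1.0 pattern, assuming an existing engine, stream, and source/destination memories (only the call sequence is the point):

// Condensed sketch of the MKL-DNN 1.0 reorder-with-output-scales pattern
// used throughout this patch; `eng`, `strm`, `src_mem`, `dst_mem` are assumed.
#include <mkldnn.hpp>
#include <vector>

void scaled_reorder_sketch(float scale, const mkldnn::memory &src_mem,
                           const mkldnn::memory &dst_mem,
                           const mkldnn::engine &eng, mkldnn::stream &strm) {
  mkldnn::primitive_attr attr;
  attr.set_output_scales(0, std::vector<float>{scale});  // mask 0: one per-tensor scale
  // 1.0 signature: engines and descriptors build the pd; no memory is bound yet.
  auto pd = mkldnn::reorder::primitive_desc(eng, src_mem.get_desc(),
                                            eng, dst_mem.get_desc(), attr);
  // Data handles are passed per execution through an args map.
  mkldnn::reorder(pd).execute(strm, {{MKLDNN_ARG_FROM, src_mem},
                                     {MKLDNN_ARG_TO, dst_mem}});
}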
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index 7cdce8e32bc8..e05ac4ceb3a6 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -24,7 +24,7 @@
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -47,8 +47,8 @@ class SgMKLDNNQuantizeOperator {
   QuantizeV2Param param_;
   float cached_data_min_{0.f};
   float cached_data_max_{0.f};
-  std::shared_ptr<mkldnn::memory> i_mem_;
-  std::shared_ptr<mkldnn::memory> o_mem_;
+  mkldnn::memory::desc o_desc_;
+  mkldnn_args_map_t args_;
   std::shared_ptr<mkldnn::reorder> fwd_pd_;
 };
@@ -127,36 +127,29 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector
     std::vector<float> scales = {scale};
     attr.set_output_scales(mask, scales);
-    attr.set_int_output_round_mode(round_nearest);
     mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
-    auto i_mpd = i_mem->get_primitive_desc();
-    auto i_desc = i_mpd.desc();
-    mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
-    if (i_fmt == mkldnn::memory::format::nchw || i_fmt == mkldnn::memory::format::nChw8c ||
-        i_fmt == mkldnn_nChw16c) {
-      i_fmt = mkldnn::memory::format::nhwc;
-    }
+    auto i_desc = i_mem->get_desc();
     size_t i_ndim = in_buffer.shape().ndim();
-    mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-    for (size_t i = 0; i < i_ndim; i++) {
-      i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+    if (i_ndim == 4) {
+      mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc;
+      mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
+      o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type(out_type), o_fmt);
+    } else {
+      o_desc_ = i_desc;
+      o_desc_.data.data_type = get_mkldnn_type_t(out_type);
    }
-    auto o_desc = mkldnn::memory::desc(i_dims, get_mkldnn_type(out_type), i_fmt);
-    auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
-    auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-    i_mem_ = std::make_shared<mkldnn::memory>(i_mpd, nullptr);
-    o_mem_ = std::make_shared<mkldnn::memory>(o_mpd, nullptr);
-    fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *i_mem_, *o_mem_);
+    auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
+    fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
     initalized_ = true;
   }
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]);
-  i_mem_->set_data_handle(i_mem->get_data_handle());
-  o_mem_->set_data_handle(o_mem.second->get_data_handle());
-  MKLDNNStream::Get()->RegisterPrim(*fwd_pd_);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]);
+  args_[MKLDNN_ARG_FROM] = *i_mem;
+  args_[MKLDNN_ARG_TO] = *o_mem.second;
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
 }
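
The stateful quantize/dequantize operators above create the reorder primitive once and then only refresh the memory arguments on each call; with 1.0, the cached state shrinks to the output descriptor, the args map, and the primitive itself. A simplified sketch of that caching pattern follows; the class name and members are illustrative, not the patch's own:

// Illustrative reduction of the caching pattern used by
// SgMKLDNNQuantizeOperator above: build once, then swap data handles
// through the args map on every Forward call.
#include <mkldnn.hpp>
#include <memory>
#include <unordered_map>

class CachedScaledReorder {
 public:
  explicit CachedScaledReorder(float scale) { attr_.set_output_scales(0, {scale}); }

  void Forward(const mkldnn::memory &in, const mkldnn::memory &out,
               mkldnn::stream &strm) {
    if (!fwd_) {  // first call: create and cache the primitive
      auto eng = in.get_engine();
      auto pd = mkldnn::reorder::primitive_desc(eng, in.get_desc(),
                                                eng, out.get_desc(), attr_);
      fwd_ = std::make_shared<mkldnn::reorder>(pd);
    }
    args_[MKLDNN_ARG_FROM] = in;  // cheap: memory is a reference-counted handle
    args_[MKLDNN_ARG_TO] = out;
    fwd_->execute(strm, args_);
  }

 private:
  mkldnn::primitive_attr attr_;
  std::unordered_map<int, mkldnn::memory> args_;  // mkldnn_args_map_t in MXNet
  std::shared_ptr<mkldnn::reorder> fwd_;
};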
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
index f81071704762..c8c21ad46a6d 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -23,7 +23,7 @@
  * \author Wenting Jiang, Xinyu Chen
 */
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../../nn/mkldnn/mkldnn_base-inl.h"
 #include "../../nn/mkldnn/mkldnn_convolution-inl.h"
 #include "../../nn/convolution-inl.h"
@@ -43,11 +43,12 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
   NDArray weight = in_data[conv::kWeight];
   ConvolutionParam param = nnvm::get<ConvolutionParam>(attrs.parsed);
-  auto &fwd = GetConvFwd(
-      param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
-      param.no_bias ? nullptr : &in_data[conv::kBias],
-      out_data[conv::kOut]);
-  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+  MKLDNNConvFullParam full_param;
+  full_param.conv_param = param;
+  full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+  auto &fwd = GetConvFwd(full_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
+                         param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]);
+  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.GetPd().src_desc());
   const mkldnn::memory *weight_mem;
   // For inference, we want to reorder the weight array so we don't need to
   // reorder data every time.
@@ -55,20 +56,22 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
+    weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc());
+    weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
   } else {
     weight_mem = weight.GetMKLDNNData();
-    CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
   }
-  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(),
-                                 req[conv::kOut]);
-  const mkldnn::memory *bias_mem = nullptr;
-  if (!param.no_bias)
-    bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
-  fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
-
+  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.GetPd().dst_desc(),
+                                 req[conv::kOut]);
+  mkldnn_args_map_t net_args;
+  if (!param.no_bias) {
+    const mkldnn::memory *bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc());
+    net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
+  }
+  net_args.insert({MKLDNN_ARG_SRC, *data_mem});
+  net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem});
+  net_args.insert({MKLDNN_ARG_DST, *out_mem.second});
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
   CommitOutput(out_data[conv::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
   Stream<cpu> *s = ctx.get_stream<cpu>();
diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
index 03d9b9067b57..79611d8bd55b 100644
--- a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
@@ -24,7 +24,7 @@
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_REQUANTIZE_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_REQUANTIZE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -71,11 +71,10 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
   float second_scale = second_quantized_range / second_real_range;
   float scale = first_scale * second_scale;
-  primitive_attr attr;
+  mkldnn::primitive_attr attr;
   const int mask = 0;
   std::vector<float> scales = {scale};
   attr.set_output_scales(mask, scales);
-  attr.set_int_output_round_mode(round_nearest);
   mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
   NDArray in_buffer = inputs[0];
@@ -83,20 +82,13 @@
     in_buffer = inputs[0].Reorder2Default();
   auto i_mem = in_buffer.GetMKLDNNData();
-  auto i_mpd = i_mem->get_primitive_desc();
-  auto i_desc = i_mpd.desc();
-  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
-  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_dim);
-  for (size_t i = 0; i < i_dim; i++) {
-    i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
-  }
-  auto o_desc = mkldnn::memory::desc(i_dims,
-                                     (mkldnn::memory::data_type)data_type_enum<DstType>::type,
-                                     i_fmt);
-  auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
-  auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  auto i_desc = i_mem->get_desc();
+  auto o_desc = i_desc;
+  o_desc.data.data_type = get_mkldnn_type_t<DstType>();
+  auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]);
+  MKLDNNStream::Get()->RegisterPrimArgs(
+      mkldnn::reorder(reorder_pd), {{MKLDNN_ARG_FROM, *i_mem}, {MKLDNN_ARG_TO, *o_mem.second}});
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
 }
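
The requantize kernel above fuses two steps into one reorder: first_scale maps the int32 accumulator back toward real values (dequantize direction), second_scale maps real values into the int8 output range (quantize direction), and the product of the two is handed to set_output_scales. A numeric illustration with hypothetical calibrated ranges:

// Numeric sketch of the requantize scale composition above
// (hypothetical calibrated ranges, int32 accumulator -> int8 output).
#include <cstdio>

int main() {
  // Step 1: int32 -> real (dequantize direction).
  const float first_real_range = 500.0f;  // hypothetical |max| of the int32 data
  const float first_quantized_range = 2147483647.0f;
  const float first_scale = first_real_range / first_quantized_range;
  // Step 2: real -> int8 (quantize direction).
  const float second_real_range = 500.0f;
  const float second_quantized_range = 127.0f;
  const float second_scale = second_quantized_range / second_real_range;
  // The fused reorder applies the product in a single pass.
  const float scale = first_scale * second_scale;
  std::printf("fused requantize scale = %g\n", scale);
  return 0;
}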
.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FCreateOpState", CreateQuantizeV2State) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", SgMKLDNNQuantizeForward) #endif diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc index 9d774ddf24f1..5406c466287d 100644 --- a/src/operator/quantization/quantized_conv.cc +++ b/src/operator/quantization/quantized_conv.cc @@ -24,7 +24,7 @@ * \author Ziheng Jiang, Jun Wu */ #include "../nn/convolution-inl.h" -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../nn/mkldnn/mkldnn_ops-inl.h" #endif @@ -114,7 +114,7 @@ bool QuantizedConvStorageType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { *dispatch_mode = DispatchMode::kFCompute; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (dev_mask == mshadow::cpu::kDevMask) { *dispatch_mode = DispatchMode::kFComputeEx; } diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc index 9ee299cf4ae9..0eb592f40dde 100644 --- a/src/operator/quantization/requantize.cc +++ b/src/operator/quantization/requantize.cc @@ -24,7 +24,7 @@ */ #include "./requantize-inl.h" #include "./quantize-inl.h" -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "./mkldnn/mkldnn_requantize-inl.h" #endif @@ -38,7 +38,7 @@ bool RequantizeStorageType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { *dispatch_mode = DispatchMode::kFCompute; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (dev_mask == mshadow::cpu::kDevMask) { *dispatch_mode = DispatchMode::kFComputeEx; } @@ -71,7 +71,7 @@ inference accuracy. // TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, // will be reverted after the improvement of CachedOP is done. .set_attr("FGradient", MakeZeroGradNodes) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNRequantizeForward) #else diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h index fcf767adebad..89bb80fe8b60 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h @@ -19,7 +19,7 @@ #ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include @@ -47,10 +47,10 @@ static inline bool IsOutputUInt8(const MKLDNNConvFusionParam& param) { act_alg == mkldnn::algorithm::eltwise_bounded_relu); }; if ((!mkldnn_param.with_sum) && mkldnn_param.with_act) { - CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::algorithm_undef); + CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::undef); result = IsOutputUInt8Helper(param.full_conv_param.act_param.alg); } else if (mkldnn_param.with_postsum_act) { - CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::algorithm_undef); + CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::undef); result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param.alg); } return result; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index d9bfa02b8820..04c738ddac11 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -17,7 +17,7 @@ * under the License. 
 */
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <utility>
 #include <vector>
@@ -128,40 +128,37 @@ static void ConvertWeightBias2MKLDNN(const MKLDNNConvFullParam &param,
                                      NDArray *weight, NDArray *bias, bool has_bias,
                                      float data_scale, const std::vector<float> &weight_scales) {
   MKLDNNStream *stream = MKLDNNStream::Get();
-  const auto new_weight = NDArray(fwd_pd.weights_primitive_desc());
+  const auto new_weight = NDArray(fwd_pd.weights_desc());
   const auto conv_weights_memory = new_weight.GetMKLDNNData();
-  primitive_attr weight_attr;
+  mkldnn::primitive_attr weight_attr;
   if (weight_scales.size()) {
     const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1;
-    weight_attr.set_int_output_round_mode(round_mode::round_nearest);
     weight_attr.set_output_scales(weight_mask, weight_scales);
   }
   auto default_weights_memory = GetWeights(*weight, param.conv_param.num_group);
   if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData();
   const auto weight_reorder_pd =
-      mkldnn::reorder::primitive_desc(default_weights_memory->get_primitive_desc(),
-                                      conv_weights_memory->get_primitive_desc(), weight_attr);
-  stream->RegisterPrim(
-      mkldnn::reorder(weight_reorder_pd, *default_weights_memory, *conv_weights_memory));
-
+      mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr);
+  MKLDNNStream::Get()->RegisterPrimArgs(
+      mkldnn::reorder(weight_reorder_pd),
+      {{MKLDNN_ARG_FROM, *default_weights_memory}, {MKLDNN_ARG_TO, *conv_weights_memory}});
   NDArray new_bias;
   if (has_bias && data_scale) {
     std::vector<float> bias_scales(weight_scales.size());
     for (size_t c = 0; c < weight_scales.size(); ++c) {
       bias_scales[c] = weight_scales[c] * data_scale;
     }
-    new_bias = NDArray(fwd_pd.bias_primitive_desc());
+    new_bias = NDArray(fwd_pd.bias_desc());
     const auto conv_bias_memory = new_bias.GetMKLDNNData();
     const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1;
-    primitive_attr bias_attr;
-    bias_attr.set_int_output_round_mode(round_mode::round_nearest);
+    mkldnn::primitive_attr bias_attr;
     bias_attr.set_output_scales(bias_mask, bias_scales);
     auto bias_weights_memory = bias->GetMKLDNNData();
-    auto bias_reorder_pd =
-        mkldnn::reorder::primitive_desc(bias_weights_memory->get_primitive_desc(),
-                                        conv_bias_memory->get_primitive_desc(), bias_attr);
-    stream->RegisterPrim(
-        mkldnn::reorder(bias_reorder_pd, *bias_weights_memory, *conv_bias_memory));
+    const auto bias_reorder_pd =
+        mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr);
+    MKLDNNStream::Get()->RegisterPrimArgs(
+        mkldnn::reorder(bias_reorder_pd),
+        {{MKLDNN_ARG_FROM, *bias_weights_memory}, {MKLDNN_ARG_TO, *conv_bias_memory}});
   }
   stream->Submit();
   *weight = new_weight;
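
The bias_scales loop above follows from how the quantized convolution accumulates: quantization maps x to x * scale, so the int8 data/weight product carries data_scale * weight_scale, and the bias added to that accumulator must be pre-scaled by the same product, per output channel when the weights are quantized per channel. A small self-contained check of that identity with hypothetical single-channel values:

// Why bias_scales[c] = weight_scales[c] * data_scale in the code above.
// Hypothetical values, just to verify the identity numerically.
#include <cstdio>

int main() {
  const float data = 0.5f, weight = 0.25f, bias = 1.0f;    // float model
  const float data_scale = 100.0f, weight_scale = 400.0f;  // hypothetical scales
  const float acc = (data * data_scale) * (weight * weight_scale);  // int domain
  const float scaled_bias = bias * weight_scale * data_scale;       // matched scale
  const float dequantized = (acc + scaled_bias) / (data_scale * weight_scale);
  std::printf("dequantized=%f, float reference=%f\n",
              dequantized, data * weight + bias);
  return 0;
}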
@@ -186,6 +183,7 @@ class SgMKLDNNConvOperator {
   nnvm::Symbol subgraph_sym_;
   MKLDNNConvFusionParam param_;
   std::shared_ptr<MKLDNNConvForward> fwd_;
+  mkldnn_args_map_t args_;
   NDArray cached_weight_;
   NDArray cached_bias_;
   float cached_data_min_;
@@ -253,22 +251,24 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
       auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
       auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
       if (outputs[kOut].dtype() == mshadow::kInt32) {
-        auto mem_desc = in_mkl_mem->get_primitive_desc().desc();
-        auto this_dtype = get_mkldnn_type(mshadow::kInt32);
-        mkldnn::memory::desc omd(
-            mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
-            this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
-        mkldnn::memory::primitive_desc opd(omd, CpuEngine::Get()->get_engine());
-        mkldnn_mem_ptr tmp_mem(new mkldnn::memory(opd, out_mkl_mem->get_data_handle()));
+        const auto& mem_desc = in_mkl_mem->get_desc();
+        const auto this_dtype = get_mkldnn_type(mshadow::kInt32);
+        auto omd = mem_desc;
+        omd.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
+        mkldnn_mem_ptr tmp_mem(new mkldnn::memory(omd, CpuEngine::Get()->get_engine(),
+                                                  out_mkl_mem->get_data_handle()));
         MKLDNNStream::Get()->RegisterMem(tmp_mem);
-        MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*in_mkl_mem, *tmp_mem));
+        MKLDNNStream::Get()->RegisterPrimArgs(
+            mkldnn::reorder(*in_mkl_mem, *tmp_mem),
+            {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
         output = NDArray(tmp_mem);
       } else {
-        mkldnn_mem_ptr tmp_mem(
-            new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle()));
-        MKLDNNStream::Get()->RegisterMem(tmp_mem);
-        mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
-        output = NDArray(tmp_mem);
+        mkldnn_mem_ptr tmp_mem(new mkldnn::memory(in_mkl_mem->get_desc(),
+                                                  CpuEngine::Get()->get_engine(),
+                                                  out_mkl_mem->get_data_handle()));
+        MKLDNNStream::Get()->RegisterMem(tmp_mem);
+        mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
+        output = NDArray(tmp_mem);
       }
     }
   }
@@ -391,27 +391,25 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
       fwd_.reset(new MKLDNNConvForward(
           full_conv_param, ctx.is_train, data, cached_weight_,
           has_bias ? &cached_bias_ : nullptr, output));
-      ConvertWeightBias2MKLDNN(full_conv_param, fwd_->fwd_pd, &cached_weight_, &cached_bias_,
+      ConvertWeightBias2MKLDNN(full_conv_param, fwd_->GetPd(), &cached_weight_, &cached_bias_,
                                has_bias, data_scale_, weight_scales_);
-      fwd_->SetNewMem(*data.GetMKLDNNData(), *cached_weight_.GetMKLDNNData(),
-                      has_bias ? cached_bias_.GetMKLDNNData() : nullptr,
-                      *output.GetMKLDNNData());
+      args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+      args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
+      if (has_bias) args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
+      args_[MKLDNN_ARG_DST] = *output.GetMKLDNNData();
       initialized_ = true;
     }
     if (mkldnn_param.with_sum) {
-      const auto output_mem = output.GetMKLDNNData();
-      const auto out_mem_desc = output_mem->get_primitive_desc().desc();
-      const auto dst_format = fwd_->fwd_pd.dst_primitive_desc().desc().data.format;
-      if (out_mem_desc.data.format != dst_format) {
-        auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->fwd_pd.dst_primitive_desc());
-        mkldnn::memory::desc data_md(
-            mkldnn::memory::dims(out_mem_desc.data.dims,
-                                 out_mem_desc.data.dims + out_mem_desc.data.ndims),
-            static_cast<mkldnn::memory::data_type>(out_mem_desc.data.data_type),
-            static_cast<mkldnn::memory::format>(dst_format));
-        mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
-        mkldnn_mem_ptr new_out_mem(new mkldnn::memory(pd, output_mem->get_data_handle()));
+      const auto& output_mem = output.GetMKLDNNData();
+      const auto& out_mem_desc = output_mem->get_desc();
+      const auto& dst_mem_desc = fwd_->GetPd().dst_desc();
+      if (out_mem_desc != dst_mem_desc) {
+        auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->GetPd().dst_desc());
+        auto data_md = dst_mem_desc;
+        data_md.data.data_type = static_cast<mkldnn_data_type_t>(out_mem_desc.data.data_type);
+        mkldnn_mem_ptr new_out_mem(new mkldnn::memory(data_md, CpuEngine::Get()->get_engine(),
+                                                      output_mem->get_data_handle()));
         MKLDNNStream::Get()->RegisterMem(new_out_mem);
         mxnet::MKLDNNCopy(*tmp_out_mem, new_out_mem.get());
         output = NDArray(new_out_mem);
@@ -419,10 +417,11 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     }
     if (mkldnn_param.quantized) {
-      auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc());
-      mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc());
-      fwd_->SetNewMem(*data_mem, *mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd_->GetFwd());
+      auto data_mem = data.GetMKLDNNDataReorder(fwd_->GetPd().src_desc());
+      mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->GetPd().dst_desc());
+      args_[MKLDNN_ARG_SRC] = *data_mem;
+      args_[MKLDNN_ARG_DST] = *mem;
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
       MKLDNNStream::Get()->Submit();
     } else {
       std::vector<NDArray> new_inputs;
@@ -441,9 +440,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
   }
   if (mkldnn_param.with_sum) {
     auto out = const_cast<NDArray &>(outputs[kOut]);
-    auto format = static_cast<mkldnn::memory::format>(
-        fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
-    out.UpdateMKLDNNMemDesc(format);
+    out.UpdateMKLDNNMemDesc(fwd_->GetPd().dst_desc());
   }
 }
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
index 40b3f7c1d010..8e5eb630ae76 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
@@ -19,7 +19,7 @@
 #ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_
 #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <string>
 #include <vector>
diff --git a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
index 38b08968d8a5..19706fe36e32 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
@@ -18,7 +18,7 @@
 */
 #ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_POST_QUANTIZE_PROPERTY_H_
 #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <string>
 #include <vector>
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index d0d2b51918b1..340dfe3b01ae 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -17,7 +17,7 @@
  * under the License.
 */
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "mkldnn_conv_property.h"
 #include "mkldnn_fc_property.h"
@@ -34,25 +34,33 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
+#endif  // MXNET_USE_MKLDNN == 100
+#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-
-
+#endif  // MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
 .set_attr("context", Context::CPU());
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
 .set_attr("quantize", true);
+#endif  // MXNET_USE_MKLDNN == 100
+#if MXNET_USE_MKLDNN == 1
+
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
 .set_attr("quantize", true);
-
+#endif  // MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
+#endif  // MXNET_USE_MKLDNN == 100
+#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
-
+#endif  // MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 }  // namespace op
 }  // namespace mxnet
-
-#endif  // MXNET_USE_MKLDNN == 1
+#endif // MXNET_USE_MKLDNN == 100

From 72111ea9f9e6da2bbfa1260e29fa3f9f1108b2ab Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Thu, 26 Sep 2019 10:24:31 +0800
Subject: [PATCH 2/3] Fix lint

---
 src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h | 3 ++-
 src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc | 3 ++-
 src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc  | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index e05ac4ceb3a6..0f695a0c59a2 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -142,7 +142,8 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector
-    auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
+    auto reorder_pd =
+        mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
     fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
     initalized_ = true;
   }
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
index c8c21ad46a6d..6a1835651b6f 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -65,7 +65,8 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
                                  req[conv::kOut]);
   mkldnn_args_map_t net_args;
   if (!param.no_bias) {
-    const mkldnn::memory *bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc());
+    const mkldnn::memory *bias_mem =
+        in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc());
     net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
   }
   net_args.insert({MKLDNN_ARG_SRC, *data_mem});
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 340dfe3b01ae..95d2b7d86147 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -63,4 +63,4 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScale
 #if MXNET_USE_MKLDNN == 100
 }  // namespace op
 }  // namespace mxnet
-#endif // MXNET_USE_MKLDNN == 100
+#endif  // MXNET_USE_MKLDNN == 100

From be2c2789c90f8da89e5528d80468fe83b510f77a Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Thu, 26 Sep 2019 10:31:28 +0800
Subject: [PATCH 3/3] Fix clang build

Change-Id: I9468774d014c852901e4cc3bffabd8a3d8004519
---
 src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
index 4c8a7ab285b3..3766fbe016e5 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
@@ -18,7 +18,7 @@
 */
 #ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
 #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../subgraph_property.h"