diff --git a/src/operator/subgraph/mkldnn/mkldnn_common.h b/src/operator/subgraph/mkldnn/mkldnn_common.h
index 87ddc438d846..2d1d66fbccaa 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_common.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_common.h
@@ -87,7 +87,7 @@ static std::vector<float> GetWeightScales(const NDArray &weight, const NDArray *
 }
 
 static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
-                                     const mkldnn::memory::desc weight_md,
+                                     const mkldnn::memory::desc &weight_md,
                                      const mkldnn::memory::desc *bias_md,
                                      const int num_group, float data_scale,
                                      const std::vector<float> &weight_scales,
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index ec8ba640c136..e2b1807b6559 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -63,10 +63,13 @@ class SgMKLDNNFCOp {
 
  private:
   bool initialized_{false};
+  bool channel_wise_runtime_{false};
+  bool reorder_data_{false};
   nnvm::Symbol subgraph_sym_;
   MKLDNNFCFullParam full_param_;
   mkldnn_args_map_t args_;
   std::shared_ptr<MKLDNNFullyConnectedForward> fwd_;
+  std::shared_ptr<mkldnn::memory> cached_data_mem_;
   std::shared_ptr<mkldnn::memory> cached_out_mem_;
   NDArray cached_weight_;
   NDArray cached_bias_;
@@ -82,28 +85,10 @@ class SgMKLDNNFCOp {
   float cached_max_output_;
   float data_scale_{0.0f};
   std::vector<float> weight_scales_;
+  size_t total_num_inputs_;
+  size_t total_num_outputs_;
 };
 
-static inline void MKLDNNFCFlattenData(const FullyConnectedParam &param,
-                                       NDArray *in_data) {
-  const mxnet::TShape ishape = in_data->shape();
-
-  // If the input data is a view of an MKLDNN array, we should create a new
-  // NDArray with reordered data.
-  if (in_data->IsMKLDNNData() && in_data->IsView())
-    *in_data = in_data->Reorder2Default();
-
-  auto data_ndim = ishape.ndim();
-  if (data_ndim != 2) {
-    if (!param.flatten) {
-      *in_data = in_data->MKLDNNDataReshape(
-          Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
-    } else {
-      *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
-    }
-  }
-}
-
 void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                            const std::vector<NDArray> &in_data,
                            const std::vector<OpReqType> &req,
@@ -112,9 +97,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   auto &default_param = full_param_.default_param;
   bool has_bias = !default_param.no_bias;
   size_t base_num_inputs = has_bias ? 3 : 2;
-  size_t total_num_inputs = base_num_inputs;
   size_t base_num_outputs = 1;
-  size_t total_num_outputs = base_num_outputs;
 
   float min_data = 0.0f;
   float max_data = 0.0f;
@@ -123,17 +106,29 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   float min_bias = 0.0f;
   float max_bias = 0.0f;
 
-  bool channel_wise = false;
-  if (mkldnn_param.channel_wise_quantize.has_value() &&
-      mkldnn_param.channel_wise_quantize) {
-    channel_wise = true;
+  if (!initialized_) {
+    if (mkldnn_param.channel_wise_quantize.has_value() &&
+        mkldnn_param.channel_wise_quantize) {
+      channel_wise_runtime_ = true;
+    }
+
+    total_num_inputs_ = base_num_inputs;
+    total_num_outputs_ = base_num_outputs;
+    if (mkldnn_param.quantized) {
+      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
+      total_num_outputs_ =
+          mkldnn_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
+    }
   }
+  CHECK_EQ(in_data.size(), total_num_inputs_);
+  CHECK_EQ(out_data.size(), total_num_outputs_);
+
+  NDArray data = in_data[fullc::kData];
+  const NDArray &weight = in_data[fullc::kWeight];
+  const NDArray &output = out_data[fullc::kOut];
 
   if (mkldnn_param.quantized) {
-    if (channel_wise) {
-      total_num_inputs = base_num_inputs + 2;
-    } else {
-      total_num_inputs = base_num_inputs * 3;
+    if (!channel_wise_runtime_) {
       min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr<float>()[0];
      max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr<float>()[0];
       if (has_bias) {
@@ -143,20 +138,11 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     }
     min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr<float>()[0];
     max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr<float>()[0];
-    if (!mkldnn_param.enable_float_output) {
-      total_num_outputs = base_num_outputs * 3;
-    }
   }
-  CHECK_EQ(in_data.size(), total_num_inputs);
-  CHECK_EQ(out_data.size(), total_num_outputs);
-
-  NDArray data = in_data[fullc::kData];
-  NDArray weight = in_data[fullc::kWeight];
-  NDArray output = out_data[fullc::kOut];
-  MKLDNNFCFlattenData(default_param, &data);
 
-  if (initialized_ && mkldnn_param.quantized) {
-    if (channel_wise) {
+  if (initialized_ && mkldnn_param.quantized &&
+      dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
+    if (channel_wise_runtime_) {
       if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
           weight_ver_ != weight.version() ||
          (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
@@ -173,6 +159,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
 
   if (!initialized_) {
     const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    const auto engine = CpuEngine::Get()->get_engine();
     cached_min_data_ = min_data;
     cached_max_data_ = max_data;
     cached_min_weight_ = min_weight;
@@ -187,9 +174,22 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     } else {
       cached_bias_ = NDArray();
     }
 
+    const mxnet::TShape ishape = data.shape();
+    const auto data_ndim = ishape.ndim();
+    if (data.IsMKLDNNData()) {
+      reorder_data_ = true;
+      data = data.Reorder2Default();
+    }
+    if (data_ndim != 2) {
+      if (!default_param.flatten) {
+        data = data.MKLDNNDataReshape(
+            Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
+      } else {
+        data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
+      }
+    }
     // create cached out_md
-    const mxnet::TShape ishape = data.shape();
     const mxnet::TShape oshape = output.shape();
     mkldnn::memory::dims out_dims(2);
     if (oshape.ndim() == 2) {
@@ -206,7 +206,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     }
     mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()),
        static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(2)));
-    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, CpuEngine::Get()->get_engine());
+    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, engine);
 
     bool support_channelwise_scale = false;
     if (mkldnn_param.quantized) {
@@ -229,7 +229,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     //   True               True                     True
     //   True               False                    Error
     //   False              True/False               False
-    if (channel_wise && !support_channelwise_scale) {
+    if (channel_wise_runtime_ && !support_channelwise_scale) {
      LOG(FATAL)
         << "Currently, channel-wise quantization requires fuse requantize or dequantize."
         << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
@@ -237,7 +237,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
         << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
        << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)";
     }
-    support_channelwise_scale = support_channelwise_scale && channel_wise;
+    support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_;
 
     if (support_channelwise_scale) {
       MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
@@ -329,17 +329,22 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                                has_bias ? &bias_md : nullptr, 1, data_scale_,
                                weight_scales_, false);
     } else {
-      cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
-      auto cached_weight_mem = cached_weight_.GetMKLDNNData();
-      auto def_weight_mem = weight.GetMKLDNNData();
-      std::unordered_map<int, mkldnn::memory> args(
-          {{MKLDNN_ARG_FROM, *def_weight_mem},
-           {MKLDNN_ARG_TO, *cached_weight_mem}});
-      MKLDNNStream::Get()->RegisterPrimArgs(
-          mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
+      const auto def_weight_mem = weight.GetMKLDNNData();
+      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
+        cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
+        auto cached_weight_mem = cached_weight_.GetMKLDNNData();
+        std::unordered_map<int, mkldnn::memory> args(
+            {{MKLDNN_ARG_FROM, *def_weight_mem},
+             {MKLDNN_ARG_TO, *cached_weight_mem}});
+        MKLDNNStream::Get()->RegisterPrimArgs(
+            mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
+      }
     }
 
-    args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+    const auto data_mem = data.GetMKLDNNData();
+    cached_data_mem_ = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);
+
+    args_[MKLDNN_ARG_SRC] = *cached_data_mem_;
     args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
     if (has_bias)
       args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
@@ -347,12 +352,15 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     initialized_ = true;
   }
 
-  auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_desc());
+  if (reorder_data_) {
+    data = data.Reorder2Default();
+  }
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    cached_data_mem_->set_data_handle(reinterpret_cast<void *>(data.data().dptr<DType>()));
+  });
   MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
     cached_out_mem_->set_data_handle(reinterpret_cast<void *>(output.data().dptr<DType>()));
   });
-  args_[MKLDNN_ARG_SRC] = *data_mem;
-  args_[MKLDNN_ARG_DST] = *cached_out_mem_;
 
   MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
   MKLDNNStream::Get()->Submit();