diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index bc61bc7dd926..f0c199bc16f0 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -541,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state, // copying A to B may not happen, and will corrupt A's memory. InvalidateOutputs(outputs, req); } + // add for mkldnn OP + no mkldnn OP + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); + if (!is_mkldnn.get(attrs.op, false)) { + std::vector inputs_fallback; + CreateDefaultInputs(inputs, &inputs_fallback); + fcompute_ex(state, opctx, inputs_fallback, req, outputs); + } else { +#endif + fcompute_ex(state, opctx, inputs, req, outputs); +#if MXNET_USE_MKLDNN == 100 + } #endif - fcompute_ex(state, opctx, inputs, req, outputs); if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && rctx.get_stream() && !rctx.is_bulk) { rctx.get_stream()->Wait(); diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 06ad6d034398..c80c08e0a381 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -97,7 +97,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs, valid_bias = inputs[2].storage_type() == kDefaultStorage || inputs[2].storage_type() == kRowSparseStorage; } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (common::ContainsOnlyStorage(inputs, kDefaultStorage) && common::ContainsOnlyStorage(outputs, kDefaultStorage)) { if (SupportMKLDNNFC(inputs[0])) { @@ -141,7 +141,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs, #endif } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, @@ -199,7 +199,7 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs, dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (!MKLDNNEnvSet()) *dispatch_mode = DispatchMode::kFComputeFallback; #endif @@ -233,7 +233,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (!MKLDNNEnvSet()) *dispatch_mode = DispatchMode::kFComputeFallback; #endif @@ -295,7 +295,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored. [](const NodeAttrs& attrs) { return std::vector{"output"}; }) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; @@ -326,7 +326,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected) }) .set_attr("FInferStorageType", BackwardFCStorageType) .set_attr_parser(ParamParser) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", FullyConnectedGradComputeExCPU) #endif diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h index fddaedc2459d..db8cfdc986ec 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h @@ -27,7 +27,7 @@ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include @@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter { DMLC_DECLARE_FIELD(enable_float_output).set_default(false) .describe("Whether to enable float32 output"); DMLC_DECLARE_FIELD(with_eltwise).set_default(false) - .describe("Whether there's a post elemwise after FullyConnected operator"); + .describe("Whether there's a post with_eltwise after FullyConnected operator"); DMLC_DECLARE_FIELD(min_calib_range) .set_default(dmlc::optional()) .describe("The minimum scalar value in the form of float32 obtained " @@ -85,10 +85,9 @@ class MKLDNNFullyConnectedForward { const NDArray &data, const NDArray &weight, const NDArray *bias, const mkldnn::memory::desc &out_md) - : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {} - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output); + : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) { + fwd_ = std::make_shared(fwd_pd); + } const mkldnn::inner_product_forward &GetFwd() const { return *fwd_; @@ -96,10 +95,6 @@ class MKLDNNFullyConnectedForward { private: std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr weight_; - std::shared_ptr bias_; - std::shared_ptr out_; }; typedef ParamOpSign MKLDNNFullyconSignature; diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index fbe37e227cd1..80eb2d6727bd 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -24,7 +24,7 @@ * \author Da Zheng, Ciyong Chen */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "mkldnn_fully_connected-inl.h" namespace mxnet { @@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( } attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); } } @@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei } } -void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data, - const mkldnn::memory &weight, - const mkldnn::memory *bias, - const mkldnn::memory &output) { - if (this->data_ == nullptr) - this->data_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data_->set_data_handle(data.get_data_handle()); - - if (this->weight_ == nullptr) - this->weight_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight_->set_data_handle(weight.get_data_handle()); - - if (this->out_ == nullptr) - this->out_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (bias != nullptr) { - if (this->bias_ == nullptr) - this->bias_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.bias_primitive_desc(), bias->get_data_handle())); - else - this->bias_->set_data_handle(bias->get_data_handle()); - - if (this->fwd_ == nullptr) - this->fwd_ = std::shared_ptr( - new mkldnn::inner_product_forward( - fwd_pd, mkldnn::primitive::at(*this->data_), - mkldnn::primitive::at(*this->weight_), - mkldnn::primitive::at(*this->bias_), *this->out_)); - } else { - if (this->fwd_ == nullptr) { - this->fwd_ = std::shared_ptr( - new mkldnn::inner_product_forward( - fwd_pd, mkldnn::primitive::at(*this->data_), - mkldnn::primitive::at(*this->weight_), *this->out_)); - } - } -} - MKLDNNFullyConnectedForward &GetFCFwd( const FullyConnectedParam ¶m, const bool is_train, const NDArray &data, const NDArray &weight, @@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam ¶m, mkldnn::memory::dims out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), static_cast(oshape[ishape.ndim()-1])}; *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()), - mkldnn::memory::format::any); + mkldnn::memory::format_tag::any); } else { *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); mkldnn::memory::dims out_dims{static_cast(oshape[0]), static_cast(oshape.ProdShape(1, oshape.ndim()))}; *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()), - mkldnn::memory::format::any); + mkldnn::memory::format_tag::any); } } } @@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param, NDArray weight = in_data[fullc::kWeight]; NDArray data = in_data[fullc::kData]; - auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { if (weight.IsMKLDNNData()) { weight.Reorder2DefaultAsync(); } - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1); } else { - if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. - // Don't switch below sequence because naive engine will executes - // pushAsync synchronously. - weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); - } else { - weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); + weight_mem = weight.GetMKLDNNData(); + if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) { + // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug + // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc()); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1); } } auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], - fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data); + fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data); + + std::unordered_map args = { + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DST, *out_mem.second}, + }; if (!full_param.default_param.no_bias) { auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder( - fwd->fwd_pd.bias_primitive_desc()); - fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - } else { - fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); + fwd->fwd_pd.bias_desc()); + args.insert({ MKLDNN_ARG_BIAS, *bias_mem}); } - MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); + MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args); CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Get()->Submit(); } @@ -339,13 +293,18 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData( data, weight, out_grad, fwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdData_pd.diff_dst_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); + ipBwdData_pd.diff_dst_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc()); auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], - ipBwdData_pd.diff_src_primitive_desc(), + ipBwdData_pd.diff_src_desc(), req[fullc::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( - ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); + std::unordered_map args = { + {MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second} + }; + + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args); CommitOutput(in_grad[fullc::kData], in_grad_mem); } if (req[fullc::kWeight]) { @@ -353,23 +312,26 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], out_grad, fwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); + ipBwdWeights_pd.diff_dst_desc()); + auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc()); auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], - ipBwdWeights_pd.diff_weights_primitive_desc(), + ipBwdWeights_pd.diff_weights_desc(), req[fullc::kWeight]); + std::unordered_map args = { + {MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}, + }; + mkldnn_output_t in_grad_bias; - if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); - } else { + if (!param.no_bias) { in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], - ipBwdWeights_pd.diff_bias_primitive_desc(), + ipBwdWeights_pd.diff_bias_desc(), req[fullc::kBias]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, - *in_grad_bias.second)); + args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second}); } + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args); CommitOutput(in_grad[fullc::kWeight], in_grad_weight); CommitOutput(in_grad[fullc::kBias], in_grad_bias); } @@ -378,4 +340,4 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 951b0754fff9..20d80cd3b3e5 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -44,16 +44,6 @@ namespace mxnet { namespace op { #if MXNET_USE_MKLDNN == 1 -/* For fully connected. */ -void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data); -void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); - /* For deconvolution */ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, @@ -104,6 +94,16 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, #endif #if MXNET_USE_MKLDNN == 100 +/* For fully connected. */ +void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + /* For convolution. */ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data,