From 9f56e2b87edc367e883fc7f8df3287ef68ca998f Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Fri, 20 Sep 2019 15:05:45 +0800
Subject: [PATCH] add mkldnn lrn

---
 src/operator/nn/lrn.cc                  |  12 +-
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 192 +++++++++---------------
 2 files changed, 76 insertions(+), 128 deletions(-)

diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index 3a3ca59f2be1..e729f907f51e 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -26,7 +26,7 @@
 
 #include "./lrn-inl.h"
 #include "../operator_common.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_lrn-inl.h"
 #include "./mkldnn/mkldnn_base-inl.h"
 #endif
@@ -82,7 +82,7 @@ struct LRNGrad {
   }
 };
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                 const int dev_mask,
                                 DispatchMode* dispatch_mode,
@@ -169,7 +169,7 @@ number of kernels in the layer.
 .set_attr_parser(ParamParser<LRNParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", LRNShape)
 .set_attr<nnvm::FInferType>("FInferType", LRNType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", LRNForwardInferStorageType)
 #endif
 .set_attr<nnvm::FListInputNames>("FListInputNames",
@@ -181,7 +181,7 @@ number of kernels in the layer.
       return std::vector<std::string>{"output", "tmp_norm"};
     })
 .set_attr<FCompute>("FCompute<cpu>", LRNCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", LRNComputeExCPU)
 #endif
@@ -192,11 +192,11 @@ number of kernels in the layer.
 NNVM_REGISTER_OP(_backward_LRN)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<LRNParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", LRNBackwardInferStorageType)
 #endif
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
 // Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 31b293a14c2c..8436339d5c13 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -25,7 +25,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <utility>
 #include <mkldnn.hpp>
 #include "../lrn-inl.h"
@@ -34,27 +34,27 @@
 namespace mxnet {
 namespace op {
 
-inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
+inline mkldnn::algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   // TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
   //               Need to confirm with MKLDNN team and fix later
-  return algorithm::lrn_across_channels;
+  return mkldnn::algorithm::lrn_across_channels;
 }
 
 inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
-    const LRNParam &param, const bool is_train, const memory::desc &src_md) {
+    const LRNParam &param, const bool is_train, const mkldnn::memory::desc &src_md) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
-  const algorithm alg = GetMKLDNNLRNAlgo(param);
+  const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int nsize = param.nsize;
   const float k = param.knorm;
-  auto kind = prop_kind::forward_training;
+  auto kind = mkldnn::prop_kind::forward_training;
   if (is_train) {
-    kind = prop_kind::forward_training;
+    kind = mkldnn::prop_kind::forward_training;
   } else {
-    kind = prop_kind::forward_scoring;
+    kind = mkldnn::prop_kind::forward_scoring;
   }
-  lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
+  mkldnn::lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
   return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
 }
 
@@ -63,13 +63,13 @@ inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
     const mkldnn::memory::desc &diff_md,
     const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
-  const algorithm alg = GetMKLDNNLRNAlgo(param);
+  const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int nsize = param.nsize;
   const float k = param.knorm;
 
-  lrn_backward::desc lrnBwd_desc(alg, data_in_md,
+  mkldnn::lrn_backward::desc lrnBwd_desc(alg, data_in_md,
                                  diff_md, nsize, alpha, beta, k);
   return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc, engine,
                                               lrnFwd_desc);
@@ -83,33 +83,24 @@ class MKLDNNLRNFwd {
 class MKLDNNLRNFwd {
  public:
   MKLDNNLRNFwd(const LRNParam& param,
                bool is_train,
-               const NDArray &in_data):
-               is_train(is_train) {
+               const NDArray &in_data) {
     _Init(param, is_train, in_data);
   }
 
   ~MKLDNNLRNFwd() {}
 
-  void SetNewMem(const NDArray &data,
-                 const NDArray &output,
-                 const OpReqType req);
-
-  void SetNewMem(const NDArray &in_data,
-                 const mkldnn::memory *out_mem);
-
-  void Execute(const NDArray &out_data);
+  void Execute(const OpContext &ctx,
+               const NDArray &in_data,
+               const OpReqType req,
+               const NDArray &out_data);
 
   mkldnn::lrn_forward &GetFwd();
-  const mkldnn::memory *GetWs();
+  mkldnn::lrn_forward::primitive_desc &GetFwdPd();
 
  private:
   std::shared_ptr<mkldnn::lrn_forward> fwd;
-  std::shared_ptr<mkldnn::memory> in_mem;
-  std::shared_ptr<mkldnn::memory> out_mem;
-  std::shared_ptr<mkldnn::memory> ws_mem;
-  mkldnn_output_t output_mem_t;
-  bool is_train;
+  mkldnn::lrn_forward::primitive_desc fwd_pd;
 
  private:
   void _Init(const LRNParam &param, bool is_train, const NDArray &in_data);
 };
@@ -119,52 +110,37 @@ void MKLDNNLRNFwd::_Init(const LRNParam &param,
                          bool is_train,
                          const NDArray &in_data) {
   mkldnn::memory::desc in_data_md =
-      in_data.GetMKLDNNData()->get_primitive_desc().desc();
-  mkldnn::lrn_forward::primitive_desc fwd_pd =
+      in_data.GetMKLDNNData()->get_desc();
+  this->fwd_pd =
       GetLRNFwdDesc(param, is_train, in_data_md);
-  this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
-                     ->get_primitive_desc()));
-  this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
-  if (is_train) {
-    // If it's training, we have to create a workspace memory. Otherwise,
-    // MKLDNN will have segmentation fault.
-    ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
-    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
-        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
-                                *this->ws_mem, *this->out_mem));
-  } else {
-    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
-        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)),
-                                *(this->out_mem)));
-  }
-}
-
-void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
-                             const NDArray &out_data,
-                             const OpReqType req) {
-  const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
-  output_mem_t = CreateMKLDNNMem(out_data, this->out_mem->get_primitive_desc(), req);
-  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
-  this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
+  this->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward(this->fwd_pd));
 }
 
-void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
-                             const mkldnn::memory *out_mem) {
-  const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
-  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
-  this->out_mem->set_data_handle(out_mem->get_data_handle());
-}
-
-void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
-  MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
+void MKLDNNLRNFwd::Execute(const OpContext &ctx,
+                           const NDArray &in_data,
+                           const OpReqType req,
+                           const NDArray &out_data) {
+  auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req);
+
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData() },
+    { MKLDNN_ARG_DST, *output_mem_t.second },
+  };
+  std::shared_ptr<mkldnn::memory> workspace;
+  if (ctx.is_train) {
+    auto engine = CpuEngine::Get()->get_engine();
+    workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
+    args[MKLDNN_ARG_WORKSPACE] = *(workspace);
+  }
+  MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd), args);
   CommitOutput(out_data, output_mem_t);
   MKLDNNStream::Get()->Submit();
 }
 
 mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
+mkldnn::lrn_forward::primitive_desc &MKLDNNLRNFwd::GetFwdPd() { return this->fwd_pd; }
 
-const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
 // End of LRN Class and its functions
 
 static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
@@ -180,10 +156,11 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
                                          OpHash> lrn_fwds;
 #endif
   auto kind_ =
-      ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
+      ctx.is_train ? mkldnn::prop_kind::forward_training
+                   : mkldnn::prop_kind::forward_scoring;
 
   MKLDNNLRNSignature key(param);
-  key.AddSign(kind_);
+  key.AddSign(static_cast<int>(kind_));
   key.AddSign(in_data);
 
   auto it = lrn_fwds.find(key);
@@ -201,17 +178,12 @@ void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
   if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
     in_buffer = in_buffer.Reorder2Default();
   MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
-  fwd.SetNewMem(in_buffer, out_data, req);
-  fwd.Execute(out_data);
+  fwd.Execute(ctx, in_buffer, req, out_data);
 }
 
 // LRN Backward Class
 class MKLDNNLRNBwd {
   std::shared_ptr<mkldnn::lrn_backward> bwd;
-  std::shared_ptr<mkldnn::memory> in_data_mem;
-  std::shared_ptr<mkldnn::memory> diff_dst_mem;
-  std::shared_ptr<mkldnn::memory> ws_mem;
-  std::shared_ptr<mkldnn::memory> diff_src_mem;
 
  public:
   const mkldnn::lrn_forward::primitive_desc fwd_pd;
@@ -222,40 +194,26 @@ class MKLDNNLRNBwd {
   MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
                const mkldnn::memory::desc diff_md)
       : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
-        bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}
-
-  void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
-                 const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
-    if (bwd == nullptr) {
-      this->in_data_mem.reset(
-          new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
-                             in_data.GetMKLDNNData()->get_data_handle()));
-      this->diff_dst_mem.reset(
-          new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
-                             out_grad.GetMKLDNNData()->get_data_handle()));
-      this->ws_mem.reset(
-          new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
-                             ws->get_data_handle()));
-      this->diff_src_mem.reset(
-          new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
-                             diff_src_mem->get_data_handle()));
-      this->bwd.reset(new mkldnn::lrn_backward(
-          this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
-          mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
-          *this->diff_src_mem));
-    } else {
-      this->in_data_mem->set_data_handle(
-          in_data.GetMKLDNNData()->get_data_handle());
-      this->diff_dst_mem->set_data_handle(
-          out_grad.GetMKLDNNData()->get_data_handle());
-      this->ws_mem->set_data_handle(ws->get_data_handle());
-      this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
-    }
-  }
-
-  void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
-    MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
-    CommitOutput(in_grad, diff_src_mem_);
+        bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {
+    bwd = std::make_shared<mkldnn::lrn_backward>(bwd_pd);
+  }
+
+  const mkldnn::lrn_backward &GetBwd() const { return *bwd; }
+
+  void Execute(const NDArray &out_grad,
+               const NDArray &in_data,
+               const NDArray &in_grad,
+               const mkldnn_output_t &diff_src_mem) {
+    auto engine = CpuEngine::Get()->get_engine();
+    auto workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
+    mkldnn_args_map_t args = {
+      { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData() },
+      { MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData() },
+      { MKLDNN_ARG_WORKSPACE, *workspace },
+      { MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }
+    };
+    MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args);
    CommitOutput(in_grad, diff_src_mem);
     MKLDNNStream::Get()->Submit();
   }
 };  // End of LRN Class
@@ -277,9 +235,9 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
   auto it = lrn_bwds.find(key);
   if (it == lrn_bwds.end()) {
     const mkldnn::memory::desc in_data_md =
-        in_data.GetMKLDNNData()->get_primitive_desc().desc();
+        in_data.GetMKLDNNData()->get_desc();
     const mkldnn::memory::desc diff_md =
-        out_grad.GetMKLDNNData()->get_primitive_desc().desc();
+        out_grad.GetMKLDNNData()->get_desc();
     MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
     it = AddToCache(&lrn_bwds, key, bwd);
   }
@@ -300,23 +258,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
     in_buffer = in_data.Reorder2Default();
   }
   MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
-  // Repeat FW for getting workspace
-  // TODO(Patric): To keep the function stateless, we can't pass workspace
-  //               from LRN forward to backward. We have to re-compute
-  //               LRN forward to get the workspace.
-  //               Will refine this code later.
-  MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
-  std::shared_ptr<mkldnn::memory> dst_temp(
-      new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
-  fwd.SetNewMem(in_buffer, dst_temp.get());
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
-  mkldnn_output_t diff_src_mem =
-      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
-  bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
-  bwd.Execute(in_grad, diff_src_mem);
+  mkldnn_output_t diff_src_mem =
+      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req);
+
+  bwd.Execute(out_grad, in_buffer, in_grad, diff_src_mem);
 }
 
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__
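
Note on the execution model this patch tracks: with MXNET_USE_MKLDNN == 100 the backend is MKL-DNN 1.0, where a primitive no longer owns its source/destination memories. It is constructed from the primitive_desc alone, and the memories (including the training workspace) are supplied at execution time through an argument map keyed by the MKLDNN_ARG_* macros; MKLDNNStream::RegisterPrimArgs queues exactly such a primitive/argument pair. That is why the SetNewMem plumbing and the cached in_mem/out_mem/ws_mem handles are deleted above, and why caching MKLDNNLRNFwd by (prop_kind, input) remains safe: the cached object holds no data pointers. The standalone sketch below shows the same sequence of calls against the plain MKL-DNN 1.0 C++ API, outside MXNet; the tensor shape and LRN hyper-parameters are invented purely for illustration.

// Minimal sketch of an LRN forward pass with the MKL-DNN 1.0 C++ API.
// Not MXNet code; shape and LRN parameters below are made up.
#include <mkldnn.hpp>
#include <unordered_map>
#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  // A hypothetical 1x8x4x4 fp32 input in NCHW layout.
  mkldnn::memory::dims shape = {1, 8, 4, 4};
  mkldnn::memory::desc src_md(shape, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);

  // Forward-inference LRN across channels; no workspace is needed for scoring.
  mkldnn::lrn_forward::desc fwd_desc(mkldnn::prop_kind::forward_scoring,
                                     mkldnn::algorithm::lrn_across_channels,
                                     src_md, /*local_size=*/5,
                                     /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/2.f);
  mkldnn::lrn_forward::primitive_desc fwd_pd(fwd_desc, eng);

  std::vector<float> data(1 * 8 * 4 * 4, 1.f);
  mkldnn::memory src_mem(src_md, eng, data.data());
  mkldnn::memory dst_mem(fwd_pd.dst_desc(), eng);

  // The primitive holds no data pointers; they arrive via the args map,
  // which is what MKLDNNStream::RegisterPrimArgs forwards inside MXNet.
  mkldnn::lrn_forward lrn(fwd_pd);
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src_mem},
      {MKLDNN_ARG_DST, dst_mem}};
  lrn.execute(strm, args);
  strm.wait();
  return 0;
}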
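
For the backward path, the library-level contract is that lrn_backward consumes the source, the gradient with respect to the output, and the workspace a forward_training run produced, and writes the gradient with respect to the source; this is what MKLDNNLRNBwd::Execute wires up through MKLDNN_ARG_DIFF_DST, MKLDNN_ARG_WORKSPACE and MKLDNN_ARG_DIFF_SRC. A companion sketch, again against the plain MKL-DNN 1.0 API with invented shapes and hyper-parameters and with the input buffers left uninitialized for brevity:

// Sketch of the forward_training + backward workspace handshake in MKL-DNN 1.0.
// Not MXNet code; shape and LRN parameters below are made up.
#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  mkldnn::memory::desc md({1, 8, 4, 4}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nchw);

  // Forward pass in training mode so that a workspace is produced.
  mkldnn::lrn_forward::desc fwd_d(mkldnn::prop_kind::forward_training,
                                  mkldnn::algorithm::lrn_across_channels,
                                  md, 5, 1e-4f, 0.75f, 2.f);
  mkldnn::lrn_forward::primitive_desc fwd_pd(fwd_d, eng);

  mkldnn::memory src(md, eng), dst(fwd_pd.dst_desc(), eng);
  mkldnn::memory ws(fwd_pd.workspace_desc(), eng);
  mkldnn::lrn_forward(fwd_pd).execute(
      strm, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst},
             {MKLDNN_ARG_WORKSPACE, ws}});

  // Backward pass consumes src, diff_dst and the forward workspace.
  mkldnn::lrn_backward::desc bwd_d(mkldnn::algorithm::lrn_across_channels,
                                   md, md, 5, 1e-4f, 0.75f, 2.f);
  mkldnn::lrn_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);

  mkldnn::memory diff_dst(md, eng), diff_src(bwd_pd.diff_src_desc(), eng);
  mkldnn::lrn_backward(bwd_pd).execute(
      strm, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DIFF_DST, diff_dst},
             {MKLDNN_ARG_WORKSPACE, ws}, {MKLDNN_ARG_DIFF_SRC, diff_src}});
  strm.wait();
  return 0;
}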