From f3f09b729b6163f5278999a8352f621fa1f4e481 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 13 Jun 2018 12:17:50 +0800 Subject: [PATCH 01/18] Enable primitive allocation cache for _backward_LRN. Change-Id: Iefe9f720de719ec2e2f5d24a006602425136711b --- src/operator/nn/lrn.cc | 2 + src/operator/nn/mkldnn/mkldnn_base-inl.h | 5 + src/operator/nn/mkldnn/mkldnn_base.cc | 11 ++ src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 185 ++++++++++++++++------- 4 files changed, 147 insertions(+), 56 deletions(-) diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 6b3d7c818378..6903f12f37b9 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -129,6 +129,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs, MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs); MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]); MKLDNN_OPCHECK_RUN(LRNCompute, attrs, ctx, inputs, req, outputs); + // Copy outputs[1] from opcheck reference as backward check needs it. + MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector{1}); return; } FallBackCompute(LRNCompute, attrs, ctx, inputs, req, outputs); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index c6e7f9bdefdc..de53ab27f539 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -490,6 +490,9 @@ class OpCheck { const std::vector &inputs_, const std::vector &req, const std::vector &outputs_); + + void CopyResult(const std::vector &outputs_, + const std::vector& indice); }; #define MKLDNN_OPCHECK_INIT(backward, num_checks, inputs, outputs) \ @@ -499,6 +502,8 @@ class OpCheck { #define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs) \ if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs); +#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \ + if (debug) check.CopyResult(outputs, indice); } // namespace mxnet #endif diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 858f8e3261f2..20997eab6e2e 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -506,6 +506,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs, } } +void OpCheck::CopyResult(const std::vector &outputs_, + const std::vector &indice) { + CHECK(!MKLDNNStream::Get()->HasOps()); + auto non_const_outputs_ = const_cast &>(outputs_); + for (auto i = indice.begin(); i != indice.end(); ++i) { + auto mem = outputs[*i].GetMKLDNNData(); + non_const_outputs_[*i].CopyFrom(*mem); + } + MKLDNNStream::Get()->Submit(); +} + } // namespace mxnet #endif diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index adb72a2a9c46..7f143b149500 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -18,7 +18,7 @@ */ /*! 
- * \file mkldnn_lrn-inl.h + * \file mkldnn_lrn-inl.h * \brief * \Author: Patric Zhao, patric.zhao@intel.com */ @@ -40,9 +40,8 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam ¶m) { return algorithm::lrn_across_channels; } -inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam ¶m, - const bool is_train, - const memory::desc &src_md) { +inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc( + const LRNParam ¶m, const bool is_train, const memory::desc &src_md) { mkldnn::engine &engine = CpuEngine::Get()->get_engine(); const algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; @@ -59,11 +58,10 @@ inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam ¶m, return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine); } -inline mkldnn::lrn_backward::primitive_desc -GetLRNBwd(const LRNParam ¶m, - const mkldnn::memory::desc &diff_in_md, - const mkldnn::memory::desc &diff_md, - const lrn_forward::primitive_desc &lrnFwd_desc) { +inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc( + const LRNParam ¶m, const mkldnn::memory::desc &diff_in_md, + const mkldnn::memory::desc &diff_md, + const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) { mkldnn::engine &engine = CpuEngine::Get()->get_engine(); const algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; @@ -92,9 +90,10 @@ class MKLDNNLRNFwd { ~MKLDNNLRNFwd() {} - void SetDataHandle(const NDArray &data, - const NDArray &output); - + void SetDataHandle(const NDArray &data, const NDArray &output); + void SetDataHandle(const NDArray &data, const mkldnn::memory *output_mem); + const mkldnn::memory *GetWs(); + mkldnn::lrn_forward &GetFwd(); void Execute(); private: @@ -111,15 +110,17 @@ class MKLDNNLRNFwd { void MKLDNNLRNFwd::_Init(const LRNParam ¶m, bool is_train, const NDArray &in_data) { - mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc(); - lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md); + mkldnn::memory::desc in_data_md = + in_data.GetMKLDNNData()->get_primitive_desc().desc(); + mkldnn::lrn_forward::primitive_desc fwd_pd = + GetLRNFwdDesc(param, is_train, in_data_md); this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData() ->get_primitive_desc())); this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc())); if (is_train) { - // If it's training, we have to create a workspace memory. Otherwise, MKLDNN - // will have segmentation fault. + // If it's training, we have to create a workspace memory. Otherwise, + // MKLDNN will have segmentation fault. 
ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc())); this->fwd = std::shared_ptr( new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem), @@ -133,13 +134,23 @@ void MKLDNNLRNFwd::_Init(const LRNParam ¶m, void MKLDNNLRNFwd::SetDataHandle(const NDArray &in_data, const NDArray &out_data) { - const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData(); - mkldnn::memory *out_data_mem = const_cast(out_data).CreateMKLDNNData( - this->out_mem->get_primitive_desc()); + mkldnn::memory *out_data_mem = + const_cast(out_data).CreateMKLDNNData( + this->out_mem->get_primitive_desc()); + this->SetDataHandle(in_data, out_data_mem); +} + +void MKLDNNLRNFwd::SetDataHandle(const NDArray &in_data, + const mkldnn::memory *out_data_mem) { + mkldnn::memory *in_data_mem = const_cast(in_data).CreateMKLDNNData( + this->in_mem->get_primitive_desc()); this->in_mem->set_data_handle(in_data_mem->get_data_handle()); this->out_mem->set_data_handle(out_data_mem->get_data_handle()); } +const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); } +mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; } + void MKLDNNLRNFwd::Execute() { MKLDNNStream::Get()->RegisterPrim(*(this->fwd)); MKLDNNStream::Get()->Submit(); @@ -158,16 +169,10 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, MKLDNNLRNFwd, OpHash> lrn_fwds; #endif - auto alg_ = algorithm::lrn_across_channels; - auto kind_ = prop_kind::forward_training; - if (ctx.is_train) { - kind_ = prop_kind::forward_training; - } else { - kind_ = prop_kind::forward_scoring; - } + auto kind_ = + ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring; MKLDNNLRNSignature key(param); - key.AddSign(alg_); key.AddSign(kind_); key.AddSign(in_data); @@ -182,16 +187,98 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, return it->second; } -void MKLDNNLRNForward(const OpContext &ctx, - const LRNParam ¶m, - const NDArray &in_data, - const OpReqType req, +void MKLDNNLRNForward(const OpContext &ctx, const LRNParam ¶m, + const NDArray &in_data, const OpReqType req, const NDArray &out_data) { - MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_data); + MKLDNNLRNFwd &fwd = GetLRNFwd(param, ctx, in_data); fwd.SetDataHandle(in_data, out_data); fwd.Execute(); } +// LRN Backward Class +class MKLDNNLRNBwd { + std::shared_ptr bwd; + std::shared_ptr in_data_mem; + std::shared_ptr diff_dst_mem; + std::shared_ptr ws_mem; + std::shared_ptr diff_src_mem; + + public: + const mkldnn::lrn_forward::primitive_desc fwd_pd; + const mkldnn::lrn_backward::primitive_desc bwd_pd; + + ~MKLDNNLRNBwd() {} + + MKLDNNLRNBwd(const LRNParam ¶m, const mkldnn::memory::desc in_data_md, + const mkldnn::memory::desc diff_md) + : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)), + bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {} + + void SetDataHandle(const NDArray &in_data, const NDArray &out_grad, + const mkldnn::memory *ws, + const mkldnn::memory *diff_src_mem) { + if (bwd == nullptr) { + this->in_data_mem.reset( + new mkldnn::memory(this->fwd_pd.src_primitive_desc(), + in_data.GetMKLDNNData()->get_data_handle())); + this->diff_dst_mem.reset( + new mkldnn::memory(this->fwd_pd.dst_primitive_desc(), + out_grad.GetMKLDNNData()->get_data_handle())); + this->ws_mem.reset( + new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(), + ws->get_data_handle())); + this->diff_src_mem.reset( + new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(), + diff_src_mem->get_data_handle())); + this->bwd.reset(new mkldnn::lrn_backward( + this->bwd_pd, 
mkldnn::primitive::at(*this->in_data_mem), + mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem, + *this->diff_src_mem)); + } else { + this->in_data_mem->set_data_handle( + in_data.GetMKLDNNData()->get_data_handle()); + this->diff_dst_mem->set_data_handle( + out_grad.GetMKLDNNData()->get_data_handle()); + this->ws_mem->set_data_handle(ws->get_data_handle()); + this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle()); + } + } + + void Execute() { + MKLDNNStream::Get()->RegisterPrim(*(this->bwd)); + MKLDNNStream::Get()->Submit(); + } +}; // End of LRN Class + +static MKLDNNLRNBwd &GetLRNBwd(const LRNParam ¶m, const NDArray &in_data, + const NDArray &in_grad, const NDArray &out_grad) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local + std::unordered_map lrn_bwds; +#else + static MX_THREAD_LOCAL + std::unordered_map lrn_bwds; +#endif + MKLDNNLRNSignature key(param); + key.AddSign(in_data); + key.AddSign(in_grad); + key.AddSign(out_grad); + + auto it = lrn_bwds.find(key); + if (it == lrn_bwds.end()) { + const mkldnn::memory::desc in_data_md = + in_data.GetMKLDNNData()->get_primitive_desc().desc(); + const mkldnn::memory::desc diff_md = + out_grad.GetMKLDNNData()->get_primitive_desc().desc(); + MKLDNNLRNBwd bwd(param, in_data_md, diff_md); + auto ins_ret = + lrn_bwds.insert(std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, const NDArray &out_grad, const NDArray &in_data, @@ -200,36 +287,22 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, if (req == kNullOp) { return; } + MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_data, in_grad, out_grad); // Repeat FW for getting workspace - const mkldnn::memory *data_mem = in_data.GetMKLDNNData(); - const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc(); - const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train, - data_md); - // TODO(Patric): To keep the function stateless, we can't pass workspace // from LRN forward to backward. We have to re-compute // LRN forward to get the workspace. // Will refine this code later. 
- std::shared_ptr ws_mem( - new mkldnn::memory(pdesc_fwd.workspace_primitive_desc())); + MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_data); std::shared_ptr dst_temp( - new mkldnn::memory(pdesc_fwd.dst_primitive_desc())); - MKLDNNStream::Get()->RegisterPrim( - lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem), - *ws_mem, *dst_temp)); - - const mkldnn::memory::desc data_in_md = pdesc_fwd.src_primitive_desc().desc(); - const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData(); - const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc(); - const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_in_md, - diff_md, pdesc_fwd); - mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad, - pdesc_bwd.diff_src_primitive_desc(), req); - - MKLDNNStream::Get()->RegisterPrim( - lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem), - mkldnn::primitive::at(*diff_mem), *ws_mem, *diff_src_mem.second)); - MKLDNNStream::Get()->Submit(); + new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc())); + fwd.SetDataHandle(in_data, dst_temp.get()); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + + mkldnn_output_t diff_src_mem = + CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req); + bwd.SetDataHandle(in_data, out_grad, fwd.GetWs(), diff_src_mem.second); + bwd.Execute(); } } // namespace op } // namespace mxnet From d6dc8a8408614eeccb02327fa63c296cebe85ebe Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 13 Jun 2018 12:17:28 +0800 Subject: [PATCH 02/18] Enable primitive allocation cache for _backward_Pooling. Change-Id: Idbe94e21f1e2ddf711523767194b95beda19b120 --- src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 21 +++ src/operator/nn/mkldnn/mkldnn_pooling.cc | 186 +++++++++++++------- 2 files changed, 143 insertions(+), 64 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 691e1d371b5b..aa9f548f594d 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -78,6 +78,27 @@ class MKLDNNPoolingFwd { const int padding_l, const int padding_r); }; +class MKLDNNPoolingBwd { + std::shared_ptr bwd; + std::shared_ptr diff_dst; + std::shared_ptr diff_src; + std::shared_ptr ws; + bool with_workspace; + + public: + const mkldnn::pooling_backward::primitive_desc pd; + + MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc, + bool with_ws); + + ~MKLDNNPoolingBwd() {} + void SetDataHandle(const mxnet::NDArray *workspace, + const mxnet::NDArray &out_grad, + const mkldnn::memory *diff_src_mem); + const mkldnn::pooling_backward &GetBwd(); + const mkldnn::pooling_backward::primitive_desc &GetPd(); +}; + inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { return param.kernel.ndim() == 2 && (param.pool_type == pool_enum::kMaxPooling || diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 9fd88a13c465..ee08ddf94a40 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -133,10 +133,9 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { } } -mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam ¶m, - const bool is_train, - const memory::desc &data_md, - const memory::desc &out_md) { +mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( + const PoolingParam ¶m, const bool is_train, const memory::desc &data_md, + const memory::desc &out_md) { CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; int 
kernel_h_, kernel_w_; if (param.global_pool) { @@ -254,11 +253,124 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &in_data, const OpReqType req, const NDArray &out_data, const NDArray *workspace) { - auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); + auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); fwd.SetDataHandle(in_data, out_data, workspace); fwd.Execute(); } +MKLDNNPoolingBwd::MKLDNNPoolingBwd( + const pooling_backward::primitive_desc &pdesc, bool with_ws) + : with_workspace(with_ws), pd(pdesc) {} + +void MKLDNNPoolingBwd::SetDataHandle(const mxnet::NDArray *workspace, + const mxnet::NDArray &out_grad, + const mkldnn::memory *diff_src_mem) { + if (bwd == nullptr) { + diff_dst.reset( + new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(), + out_grad.GetMKLDNNData()->get_data_handle())); + diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(), + diff_src_mem->get_data_handle())); + if (with_workspace) { + CHECK(workspace != nullptr); + ws.reset( + new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(), + workspace->GetMKLDNNData()->get_data_handle())); + bwd.reset( + new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src)); + } else { + bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src)); + } + } else { + diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle()); + diff_src->set_data_handle(diff_src_mem->get_data_handle()); + if (with_workspace) { + CHECK(workspace != nullptr); + ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle()); + } + } +} + +const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() { + return *this->bwd; +} + +MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, + const NDArray &in_data, + const NDArray &in_grad, + const NDArray &out_grad) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local + std::unordered_map pooling_bwds; +#else + static MX_THREAD_LOCAL + std::unordered_map pooling_bwds; +#endif + + bool with_workspace = MKLDNNRequireWorkspace(param); + MKLDNNPoolingSignature key(param); + key.AddSign(in_data); + key.AddSign(in_grad); + key.AddSign(out_grad); + + auto it = pooling_bwds.find(key); + if (it == pooling_bwds.end()) { + auto diff_dst_mem = out_grad.GetMKLDNNData(); + auto input_mem = in_data.GetMKLDNNData(); + mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); + const mkldnn::memory::desc data_md = data_mpd.desc(); + const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], + static_cast(out_grad.shape()[2]), + static_cast(out_grad.shape()[3])}; + const memory::desc out_md( + {dims}, static_cast(data_md.data.data_type), + static_cast(data_md.data.format)); + auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md); + + const mkldnn::memory::desc diff_md = + diff_dst_mem->get_primitive_desc().desc(); + const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], + static_cast(in_grad.shape()[2]), + static_cast(in_grad.shape()[3])}; + const memory::desc diff_in_md( + {dims1}, static_cast(diff_md.data.data_type), + static_cast(diff_md.data.format)); + const mkldnn::engine cpu_engine = data_mpd.get_engine(); + const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); + + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + + int pad_t_ = 
param.pad[0], pad_b_ = param.pad[0]; + int pad_l_ = param.pad[1], pad_r_ = param.pad[1]; + int stride_h_ = param.stride[0], stride_w_ = param.stride[1]; + if (param.global_pool) { + pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0; + stride_h_ = stride_w_ = 1; + } + + const pooling_backward::desc desc( + alg, diff_in_md, diff_md, {stride_h_, stride_w_}, + {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}, + mkldnn::padding_kind::zero); + const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); + MKLDNNPoolingBwd bwd(pdesc, with_workspace); + auto ins_ret = pooling_bwds.insert( + std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &out_grad, const NDArray &in_data, const NDArray *workspace, const OpReqType req, @@ -266,68 +378,14 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, if (req == kNullOp) { return; } - TmpMemMgr::Get()->Init(ctx.requested[0]); - // mkldnn::memory - auto diff_dst_mem = out_grad.GetMKLDNNData(); - auto input_mem = in_data.GetMKLDNNData(); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - const mkldnn::memory::desc data_md = data_mpd.desc(); - const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], - static_cast(out_grad.shape()[2]), - static_cast(out_grad.shape()[3])}; - const memory::desc out_md({dims}, - static_cast(data_md.data.data_type), - static_cast(data_md.data.format)); - auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md); - - const mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc(); - const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], - static_cast(in_grad.shape()[2]), - static_cast(in_grad.shape()[3])}; - const memory::desc diff_in_md( - {dims1}, static_cast(diff_md.data.data_type), - static_cast(diff_md.data.format)); - const mkldnn::engine cpu_engine = data_mpd.get_engine(); - const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); - - int kernel_h_, kernel_w_; - if (param.global_pool) { - kernel_h_ = data_md.data.dims[2]; - kernel_w_ = data_md.data.dims[3]; - } else { - kernel_h_ = param.kernel[0]; - kernel_w_ = param.kernel[1]; - } - - int pad_t_ = param.pad[0], pad_b_ = param.pad[0]; - int pad_l_ = param.pad[1], pad_r_ = param.pad[1]; - int stride_h_ = param.stride[0], stride_w_ = param.stride[1]; - if (param.global_pool) { - pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0; - stride_h_ = stride_w_ = 1; - } - - const pooling_backward::desc desc(alg, diff_in_md, diff_md, - {stride_h_, stride_w_}, - {kernel_h_, kernel_w_}, - {pad_t_, pad_l_}, {pad_b_, pad_r_}, - mkldnn::padding_kind::zero); - const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd); + auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); auto diff_src_mem = - CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req); - - if (MKLDNNRequireWorkspace(param)) { - CHECK(workspace != nullptr); - auto workspace_mem = workspace->GetMKLDNNData(); - MKLDNNStream::Get()->RegisterPrim( - pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem), - *diff_src_mem.second)); - } else { - MKLDNNStream::Get()->RegisterPrim( - pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second)); - } + CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); + + bwd.SetDataHandle(workspace, out_grad, diff_src_mem.second); + MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); 
CommitOutput(in_grad, diff_src_mem); MKLDNNStream::Get()->Submit(); } From 9e107d2b323feda95ad92b8c041005f669eef473 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 12 Jun 2018 16:49:38 +0800 Subject: [PATCH 03/18] Enable primitive allocation cache for _backward_Activation. Change-Id: I545628ff68a54cb01b7fef323dc3de9bd47b1a19 --- src/operator/nn/mkldnn/mkldnn_act.cc | 125 ++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 24 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index b21d1238f7aa..e7fe79d0e5ff 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, stream->Submit(); } +static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( + const ActivationParam ¶m, const mkldnn::memory &input_mem, + const mkldnn::memory &diff_dst_memory, int dtype) { + mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc(); + auto cpu_engine = data_mpd.get_engine(); + auto alg = GetMKLDNNActAlgo(param); + + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType alpha = 0; + mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, + alg, data_md, alpha); + mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha); + mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, + fw_pdesc); + return bw_pdesc; + }); + LOG(INFO) << "Unsupported data type for MKLDNN activation"; + mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, + alg, data_md, 0.0); + mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0); + mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, + fw_pdesc); + return bw_pdesc; +} + +class MKLDNNActBackward { + std::shared_ptr bwd; + std::shared_ptr data; + std::shared_ptr diff_dst_memory; + std::shared_ptr diff_src_memory; + + public: + const mkldnn::eltwise_backward::primitive_desc pd; + + explicit MKLDNNActBackward(const ActivationParam ¶m, const NDArray &data, + const mkldnn::memory &mem, + const mkldnn::memory &diff_dst_memory) + : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {} + + void SetNewMem(const mkldnn::memory &data, + const mkldnn::memory &diff_dst_memory, + const mkldnn::memory &diff_src_memory) { + if (this->bwd != nullptr) { + this->data->set_data_handle(data.get_data_handle()); + this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle()); + this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle()); + } else { + this->data = std::shared_ptr(new mkldnn::memory( + data.get_primitive_desc(), data.get_data_handle())); + this->diff_dst_memory = std::shared_ptr( + new mkldnn::memory(diff_dst_memory.get_primitive_desc(), + diff_dst_memory.get_data_handle())); + this->diff_src_memory = std::shared_ptr( + new mkldnn::memory(diff_src_memory.get_primitive_desc(), + diff_src_memory.get_data_handle())); + this->bwd = std::shared_ptr( + new mkldnn::eltwise_backward( + this->pd, mkldnn::primitive::at(*this->data), + *this->diff_dst_memory, *this->diff_src_memory)); + } + } + + const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; } +}; + 
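+// GetActBackward (below) keeps one MKLDNNActBackward per unique signature
+// (the activation parameters plus the shapes/dtypes of in_data and out_grad)
+// in a thread-local map, so the eltwise backward primitive descriptor is
+// created once and reused; SetNewMem() only updates the data handles on
+// subsequent calls.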
+static inline MKLDNNActBackward &GetActBackward(const ActivationParam ¶m, + const OpContext &ctx, + const NDArray &in_data, + const NDArray &out_grad, + const mkldnn::memory &in_mem) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + MKLDNNActSignature key(param); + key.AddSign(in_data); + key.AddSign(out_grad); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData()); + auto ins_ret = + bwds.insert(std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + // For backward relu activation, it's okay to pass "out_data" as "in_data" to this // function, since the computation only involes non-zeros. void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -193,37 +287,20 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx in_buffer = in_data.Reorder2Default(); const ActivationParam& param = nnvm::get(attrs.parsed); - TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); auto diff_dst_memory = out_buffer.GetMKLDNNData(); auto input_mem = in_buffer.GetMKLDNNData(); // We need to make sure the two inputs to eltwise_backward has the same memory // descriptor. Otherwise, the perf will suffer. if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - mkldnn::memory::desc data_md = data_mpd.desc(); - mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc(); - auto cpu_engine = data_mpd.get_engine(); - + MKLDNNActBackward &bwd = + GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); - auto alg = GetMKLDNNActAlgo(param); - mkldnn_output_t diff_src_memory; - - MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, { - DType alpha = 0; - mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, - alg, data_md, alpha); - mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); - mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha); - mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, - fw_pdesc); - - diff_src_memory = CreateMKLDNNMem(in_grad, - bw_pdesc.diff_src_primitive_desc(), req); - stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem, - *diff_dst_memory, - *diff_src_memory.second)); - }); + TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); + mkldnn_output_t diff_src_memory = + CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); + bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second); + stream->RegisterPrim(bwd.GetBwd()); CommitOutput(in_grad, diff_src_memory); stream->Submit(); } From b2b71e1bafa3d9aa519042ff9487e177c8a05cd8 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 13 Jun 2018 10:17:51 +0800 Subject: [PATCH 04/18] Enable primitive allocation cache for _backward_Deconvolution. 
Change-Id: I1e9bf1b9b44bae52068a9c564dff037851e896e5 --- .../nn/mkldnn/mkldnn_deconvolution.cc | 225 +++++++++++++++--- 1 file changed, 187 insertions(+), 38 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 7f3676a70dd0..54d4f6708524 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -93,9 +93,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( return mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd); } -static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData( - const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, - bool has_bias, const NDArray &output) { +static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, const NDArray &output) { auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); @@ -116,9 +116,10 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData( strides, padding, dilate); } -static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights( - const DeconvolutionParam& param, const NDArray &data, const NDArray &weights, - bool has_bias, const NDArray &output, +static mkldnn::convolution_backward_weights::primitive_desc +GetDeconvBwdWeightsImpl( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); @@ -308,55 +309,203 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data); } -void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { +class MKLDNNDeconvBackwardData { + std::shared_ptr bwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr out; + + public: + const mkldnn::convolution_forward::primitive_desc pd; + + MKLDNNDeconvBackwardData(const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output) + : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) { + } + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory &output) { + if (bwd == nullptr) { + this->data = std::shared_ptr( + new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle())); + this->weight = std::shared_ptr( + new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle())); + this->out = std::shared_ptr( + new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle())); + bwd = std::shared_ptr( + new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + *this->out)); + } else { + this->data->set_data_handle(data.get_data_handle()); + this->weight->set_data_handle(weight.get_data_handle()); + this->out->set_data_handle(output.get_data_handle()); + } + } + + const mkldnn::convolution_forward &GetBwd() const { return *bwd; } +}; + +typedef ParamOpSign MKLDNNDeconvSignature; + +static inline MKLDNNDeconvBackwardData &GetDeconvBwdData( + const DeconvolutionParam ¶m, const NDArray 
&data, + const NDArray &weights, const NDArray &output) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map + bwds; +#else + static MX_THREAD_LOCAL std::unordered_map + bwds; +#endif + MKLDNNDeconvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. + key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNDeconvBackwardData bwd(param, data, weights, output); + auto ins_ret = bwds.insert( + std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +class MKLDNNDeconvBackwardWeights { + std::shared_ptr bwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr out; + + public: + const mkldnn::convolution_backward_weights::primitive_desc pd; + + MKLDNNDeconvBackwardWeights( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) + : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output, + bwd_data_pd)) {} + + void SetNewMem( + const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) { + if (bwd == nullptr) { + this->data = std::shared_ptr(new mkldnn::memory( + bwd_data_pd.src_primitive_desc(), data.get_data_handle())); + this->weight = std::shared_ptr(new mkldnn::memory( + bwd_data_pd.weights_primitive_desc(), weight.get_data_handle())); + this->out = std::shared_ptr(new mkldnn::memory( + bwd_data_pd.dst_primitive_desc(), output.get_data_handle())); + bwd = std::shared_ptr( + new mkldnn::convolution_backward_weights(pd, *this->data, + *this->weight, *this->out)); + } else { + this->data->set_data_handle(data.get_data_handle()); + this->weight->set_data_handle(weight.get_data_handle()); + this->out->set_data_handle(output.get_data_handle()); + } + } + + const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; } +}; + +static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map + bwds; +#else + static MX_THREAD_LOCAL std::unordered_map + bwds; +#endif + MKLDNNDeconvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. 
+ key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd); + auto ins_ret = bwds.insert( + std::pair(key, + bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); const std::vector &in_grad = outputs; - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace"; - mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData( - param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], false, - inputs[deconv::kOut]); + const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); + CHECK_NE(req[deconv::kWeight], kWriteInplace) + << "cannot write weight inplace"; + MKLDNNDeconvBackwardData &bwd_data = + GetDeconvBwdData(param, inputs[deconv::kData + 1], + inputs[deconv::kWeight + 1], inputs[deconv::kOut]); auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwdData_pd.src_primitive_desc()); + bwd_data.pd.src_primitive_desc()); if (req[deconv::kData]) { - auto weight_mem = GetWeights(inputs[deconv::kWeight + 1], - bwdData_pd.weights_primitive_desc(), - param.num_group); - auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData], - bwdData_pd.dst_primitive_desc(), - req[deconv::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd, - *out_grad_mem, *weight_mem, *in_grad_mem.second)); + auto weight_mem = + GetWeights(inputs[deconv::kWeight + 1], + bwd_data.pd.weights_primitive_desc(), param.num_group); + auto in_grad_mem = + CreateMKLDNNMem(in_grad[deconv::kData], + bwd_data.pd.dst_primitive_desc(), req[deconv::kData]); + bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); + MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd()); CommitOutput(in_grad[deconv::kData], in_grad_mem); } if (req[deconv::kWeight]) { - mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd - = GetDeconvBwdWeights(param, inputs[deconv::kData + 1], - inputs[deconv::kWeight + 1], false, inputs[deconv::kOut], bwdData_pd); - if (bwdData_pd.src_primitive_desc() != bwdWeights_pd.src_primitive_desc()) + MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights( + param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], + inputs[deconv::kOut], bwd_data.pd); + if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc()) out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwdWeights_pd.src_primitive_desc()); + bwd_weights.pd.src_primitive_desc()); auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder( - bwdWeights_pd.diff_dst_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight], - bwdWeights_pd.diff_weights_primitive_desc(), - req[deconv::kWeight]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second)); + bwd_weights.pd.diff_dst_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad( + in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(), + req[deconv::kWeight]); + bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd); + 
MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd()); CommitOutput(in_grad[deconv::kWeight], in_grad_weight); } MKLDNNStream::Get()->Submit(); if (!param.no_bias) { typedef float DType; Stream *s = ctx.get_stream(); - Tensor gbias = in_grad[deconv::kBias].data().get(s); + Tensor gbias = + in_grad[deconv::kBias].data().get(s); // If there is bias, the out grad has already been converted to the default // format, so this shouldn't cause any performance issues. - Tensor grad = inputs[deconv::kOut].data().get(s); - Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad)); + Tensor grad = + inputs[deconv::kOut].data().get(s); + Assign(gbias, req[deconv::kBias], + mshadow::expr::sumall_except_dim<1>(grad)); } } From a58ad331569b457966b0f066a4d730b31c2259ad Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 13 Jun 2018 12:19:20 +0800 Subject: [PATCH 05/18] Enable primitive allocation cache for _backward_BatchNorm. Change-Id: I9e52651bd830b8cb5d2f193076ef51606c9056f9 --- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 129 ++++++++++++------ 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 9046836e8e75..496ff99f4ee9 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, } } +class MKLDNNBNBackward { + std::shared_ptr bwd; + std::shared_ptr data_m; + std::shared_ptr diff_m; + std::shared_ptr gradi_m; + std::shared_ptr mean_m; + std::shared_ptr var_m; + const std::shared_ptr weight_m; + const std::shared_ptr gradw_m; + + public: + const t_bn_b_pdesc pd; + + explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd) + : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())), + gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())), + pd(_pd) {} + + const mkldnn::memory &GetWeight() const { return *weight_m; } + + const mkldnn::memory &GetGradw() const { return *gradw_m; } + + void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff, + const NDArray &mean, const mkldnn::memory &var, + const mkldnn::memory &gradi) { + auto mean_ptr = mean.data().dptr_; + if (bwd == nullptr) { + data_m.reset(new mkldnn::memory(data.get_primitive_desc(), + data.get_data_handle())); + diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(), + diff.get_data_handle())); + gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(), + gradi.get_data_handle())); + mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr)); + var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(), + var.get_data_handle())); + bwd.reset(new mkldnn::batch_normalization_backward( + pd, *data_m, mkldnn::primitive::at(*mean_m), + mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m, + *gradw_m)); + } else { + data_m->set_data_handle(data.get_data_handle()); + diff_m->set_data_handle(diff.get_data_handle()); + gradi_m->set_data_handle(gradi.get_data_handle()); + mean_m->set_data_handle(mean_ptr); + var_m->set_data_handle(var.get_data_handle()); + } + } + + const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; } +}; + +template +static MKLDNNBNBackward &GetBNBackward( + const BatchNormParam ¶m, const OpContext &ctx, const NDArray &in_data, + const mkldnn::memory &in_mem, const NDArray &diff_data, + const mkldnn::memory &diff_mem, unsigned flags) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local 
std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + MKLDNNBNSignature key(param); + key.AddSign(in_data); + key.AddSign(diff_data); + + auto it = bwds.find(key); + if (it == bwds.end()) { + auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags); + MKLDNNBNBackward bwd(bwd_pd); + auto ins_ret = + bwds.insert(std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + template void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &out_grad, @@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc()); else if (diff.IsDefaultData()) diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc()); - auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags); + auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc()); if (flags & use_scale_shift) { const NDArray &gamma = in_data[batchnorm::kGamma]; const NDArray &beta = in_data[batchnorm::kBeta]; - // TODO(tao): how to reuse this memory? - std::shared_ptr weight_mem( - new mkldnn::memory(bwd_pd.weights_primitive_desc())); - - DType* weight_buf = reinterpret_cast(weight_mem->get_data_handle()); + DType *weight_buf = reinterpret_cast(bwd.GetWeight().get_data_handle()); nnvm::dim_t channels_ = data.shape()[1]; for (int i = 0; i < channels_; i++) { if (!param.fix_gamma) @@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, weight_buf[channels_ + i] = (beta.data().dptr())[i]; // bias } - std::shared_ptr gradw_mem( - new mkldnn::memory(bwd_pd.diff_weights_primitive_desc())); // training but no input mean and variance if (ctx.is_train && !param.use_global_stats) { DType* moving_mean_ptr = reinterpret_cast(moving_mean.data().dptr()); DType* moving_var_ptr = reinterpret_cast(moving_var.data().dptr()); DType* out_mean_ptr = reinterpret_cast(out_mean.data().dptr()); DType* out_var_ptr = reinterpret_cast(out_var.data().dptr()); - mkldnn::memory var_mem(bwd_pd.variance_primitive_desc()); + mkldnn::memory var_mem(bwd.pd.variance_primitive_desc()); DType *tmp_var_ptr = reinterpret_cast(var_mem.get_data_handle()); DType minus_mom = (1.0f - param.momentum); @@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom; } - - std::shared_ptr out_mean_mem( - new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr)); - std::shared_ptr out_var_mem( - new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr)); - - auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd, - *data_mem, - mkldnn::primitive::at(*out_mean_mem), - mkldnn::primitive::at(var_mem), - *diff_mem, - *weight_mem, - *gradi_mem, - *gradw_mem); - - MKLDNNStream::Get()->RegisterPrim(bn_bwd); + bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem); + MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); MKLDNNStream::Get()->Submit(); } else { - std::shared_ptr imean_mem( - new mkldnn::memory(bwd_pd.mean_primitive_desc(), - moving_mean.data().dptr())); - std::shared_ptr ivar_mem( - new mkldnn::memory(bwd_pd.variance_primitive_desc(), - moving_var.data().dptr())); - auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd, - *data_mem, - mkldnn::primitive::at(*imean_mem), - 
mkldnn::primitive::at(*ivar_mem), - *diff_mem, - *weight_mem, - *gradi_mem, - *gradw_mem); - - MKLDNNStream::Get()->RegisterPrim(bn_bwd); + bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean, + *moving_var.GetMKLDNNData(), *gradi_mem); + MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); MKLDNNStream::Get()->Submit(); } // copy data from gradw_mem to in_grad[1] and in_grad[2] - DType* gw_buf = reinterpret_cast(gradw_mem->get_data_handle()); + DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); for (int i = 0; i < channels_; i++) { if (!param.fix_gamma) (in_grad[1].data().dptr())[i] = gw_buf[i]; From 97e0d34da384c3e245dd97c68e23696a59c0aa70 Mon Sep 17 00:00:00 2001 From: "Huang, Zhiyuan" Date: Wed, 13 Jun 2018 12:45:27 +0800 Subject: [PATCH 06/18] Enable primitive allocation cache for _backward_Convolution Change-Id: I0496fa2394ee036d05c58f3abc1d74af544c7bca --- src/operator/nn/mkldnn/mkldnn_convolution.cc | 205 ++++++++++++++++--- 1 file changed, 178 insertions(+), 27 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index cf04ea8da3d7..f6ad02cc0091 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx MKLDNNStream::Get()->Submit(); } +class MKLDNNConvBackward { + std::shared_ptr bwd_data; + std::shared_ptr bwd_weight; + // conv::kData + std::shared_ptr out_grad; + std::shared_ptr in_grad; + std::shared_ptr weight; + // conv::kWeight + std::shared_ptr data; + std::shared_ptr output; + std::shared_ptr in_grad_weight; + std::shared_ptr in_grad_bias; + + public: + mkldnn::convolution_backward_data::primitive_desc bwdData_pd; + mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd; + + MKLDNNConvBackward( + const ConvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray *bias, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &fwd_pd): + bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)), + bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) { + } + + void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight, + const mkldnn::memory &in_grad) { + if (this->out_grad == nullptr) + this->out_grad = std::shared_ptr(new mkldnn::memory( + bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->out_grad->set_data_handle(out_grad.get_data_handle()); + if (this->in_grad == nullptr) + this->in_grad = std::shared_ptr(new mkldnn::memory( + bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle())); + else + this->in_grad->set_data_handle(in_grad.get_data_handle()); + if (this->weight == nullptr) + this->weight = std::shared_ptr(new mkldnn::memory( + bwdData_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight->set_data_handle(weight.get_data_handle()); + if (this->bwd_data == nullptr) + this->bwd_data = std::shared_ptr( + new mkldnn::convolution_backward_data( + this->bwdData_pd, mkldnn::primitive::at(*this->out_grad), + mkldnn::primitive::at(*this->weight), *this->in_grad)); + } + +void SetWeightNewMem(const mkldnn::memory &data, + const mkldnn::memory &out_grad, + const mkldnn::memory &in_grad_weight) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); + else + 
this->data->set_data_handle(data.get_data_handle()); + if (this->output == nullptr) + this->output = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->output->set_data_handle(out_grad.get_data_handle()); + if (this->in_grad_weight == nullptr) + this->in_grad_weight = std::shared_ptr( + new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), + in_grad_weight.get_data_handle())); + else + this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); + + if (this->bwd_weight == nullptr) + this->bwd_weight = std::shared_ptr( + new mkldnn::convolution_backward_weights( + this->bwdWeights_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->output), *this->in_grad_weight)); + } + + void SetWeightNewMem(const mkldnn::memory &data, + const mkldnn::memory &out_grad, + const mkldnn::memory &in_grad_weight, + const mkldnn::memory &in_grad_bias) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + if (this->output == nullptr) + this->output = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->output->set_data_handle(out_grad.get_data_handle()); + if (this->in_grad_weight == nullptr) + this->in_grad_weight = std::shared_ptr( + new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), + in_grad_weight.get_data_handle())); + else + this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); + + if (this->in_grad_bias == nullptr) + this->in_grad_bias = std::shared_ptr( + new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(), + in_grad_bias.get_data_handle())); + else + this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle()); + if (this->bwd_weight == nullptr) + this->bwd_weight = std::shared_ptr( + new mkldnn::convolution_backward_weights( + this->bwdWeights_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->output), *this->in_grad_weight, + *this->in_grad_bias)); + } + + const mkldnn::convolution_backward_data &GetBwdData() const { + return *bwd_data; + } + + const mkldnn::convolution_backward_weights &GetBwdWeights() const { + return *bwd_weight; + } +}; + +static inline MKLDNNConvBackward &GetConvBwd( + const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &fwd_pd) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + const ConvolutionParam& param = nnvm::get(attrs.parsed); + MKLDNNConvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. 
+ key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + if (bias) + key.AddSign(*bias); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd); + auto ins_ret = bwds.insert( + std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& inputs, const std::vector& req, @@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]); CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; - mkldnn::convolution_backward_data::primitive_desc bwdData_pd - = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1], - inputs[conv::kOut], fwd_pd); + MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1], + inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd); auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder( - bwdData_pd.diff_dst_primitive_desc()); + convBwd.bwdData_pd.diff_dst_primitive_desc()); if (req[conv::kData]) { auto weight_mem = GetWeights(inputs[conv::kWeight + 1], - bwdData_pd.weights_primitive_desc(), param.num_group); + convBwd.bwdData_pd.weights_primitive_desc(), param.num_group); auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], - bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd, - *out_grad_mem, *weight_mem, *in_grad_mem.second)); + convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); + convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); + MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData()); CommitOutput(in_grad[conv::kData], in_grad_mem); } if (req[conv::kWeight]) { - mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd - = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1], - param.no_bias ? nullptr : &inputs[conv::kBias + 1], - inputs[conv::kOut], fwd_pd); - if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc()) + MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1], + inputs[conv::kWeight + 1], param.no_bias ? 
nullptr : &inputs[conv::kBias + 1], + inputs[conv::kOut], fwd_pd); + if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() != + convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()) out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder( - bwdWeights_pd.diff_dst_primitive_desc()); + convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()); auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder( - bwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight], - bwdWeights_pd.diff_weights_primitive_desc(), - req[conv::kWeight]); + convBwdWeight.bwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad( + in_grad[conv::kWeight], + convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(), + req[conv::kWeight]); mkldnn_output_t in_grad_bias; if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); + convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, + *in_grad_weight.second); + MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); } else { - in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias], - bwdWeights_pd.diff_bias_primitive_desc(), - req[conv::kBias]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, - *in_grad_bias.second)); + in_grad_bias = CreateMKLDNNMem( + in_grad[conv::kBias], + convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]); + convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, + *in_grad_weight.second, *in_grad_bias.second); + MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); CommitOutput(in_grad[conv::kBias], in_grad_bias); } CommitOutput(in_grad[conv::kWeight], in_grad_weight); @@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct } // namespace op } // namespace mxnet - #endif // MXNET_USE_MKLDNN == 1 From f7b9d3028427c0c8c1f58fef544ff7cc764a43d9 Mon Sep 17 00:00:00 2001 From: "Huang, Zhiyuan" Date: Wed, 13 Jun 2018 22:28:24 +0800 Subject: [PATCH 07/18] Enable primitive allocation cache for _backward_Fully_Connected Change-Id: I8347527ec1271b1518921a74e3581d7d84187429 --- src/operator/nn/fully_connected-inl.h | 17 + .../nn/mkldnn/mkldnn_fully_connected.cc | 307 +++++++++++++++--- 2 files changed, 286 insertions(+), 38 deletions(-) diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h index 7eba2e20e573..bff582661189 100644 --- a/src/operator/nn/fully_connected-inl.h +++ b/src/operator/nn/fully_connected-inl.h @@ -60,6 +60,11 @@ struct FullyConnectedParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(flatten).set_default(true) .describe("Whether to collapse all but the first axis of the input data tensor."); } + bool operator==(const FullyConnectedParam& other) const { + return this->num_hidden == other.num_hidden && + this->no_bias == other.no_bias && + this->flatten == other.flatten; + } }; template @@ -227,4 +232,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::FullyConnectedParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.num_hidden); + ret = dmlc::HashCombine(ret, val.no_bias); + ret = dmlc::HashCombine(ret, val.flatten); + return ret; + } +}; +} // namespace std #endif // 
MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index f86f8dbefa2b..12e75612fe06 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -82,6 +82,101 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei } } +class MKLDNNFullyConnectForward { + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr out; + std::shared_ptr bias; + std::shared_ptr ipFwd; + + public: + mkldnn::inner_product_forward::primitive_desc ipFwd_pd; + + MKLDNNFullyConnectForward(const FullyConnectedParam ¶m, bool is_train, + const NDArray &data, const NDArray &weight, + const NDArray *bias, + const mkldnn::memory::desc &output) + : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {} + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory *bias, const mkldnn::memory &output) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + + if (this->weight == nullptr) + this->weight = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight->set_data_handle(weight.get_data_handle()); + + if (this->out == nullptr) + this->out = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); + + if (bias != nullptr) { + if (this->bias == nullptr) + this->bias = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.bias_primitive_desc(), bias->get_data_handle())); + else + this->bias->set_data_handle(bias->get_data_handle()); + if (this->ipFwd == nullptr) + this->ipFwd = std::shared_ptr( + new mkldnn::inner_product_forward( + ipFwd_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + mkldnn::primitive::at(*this->bias), *this->out)); + } else if (this->ipFwd == nullptr) { + this->ipFwd = std::shared_ptr( + new mkldnn::inner_product_forward( + ipFwd_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), *this->out)); + } + } + + const mkldnn::inner_product_forward &GetIpFwd() const { + return *ipFwd; + } +}; + +typedef ParamOpSign MKLDNNFullyconSignature; + +static inline MKLDNNFullyConnectForward &GetFCFwd( + const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight, + const NDArray *bias, const mkldnn::memory::desc &output, + const bool is_train) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map fcFwds; +#else + static MX_THREAD_LOCAL std::unordered_map fcFwds; +#endif + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + MKLDNNFullyconSignature key(param); + key.AddSign(data); + key.AddSign(weight); + key.AddSign(is_train); + + if (bias) + key.AddSign(*bias); + + auto it = fcFwds.find(key); + if (it == fcFwds.end()) { + MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias, + output); + auto ins_ret = fcFwds.insert( + std::pair(key, fcFwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, const std::vector &req, @@ -112,25 +207,168 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, out_md = 
mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), mkldnn::memory::format::any); } - - mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, - param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train); - auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc()); + MKLDNNFullyConnectForward &FCFwd = + GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias], + out_md, ctx.is_train); + auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc()); auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], - ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); - if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward( - ipFwd_pd, *data_mem, *weight_mem, *out_mem.second)); + FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); + if (!param.no_bias) { + auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder( + FCFwd.ipFwd_pd.bias_primitive_desc()); + FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); } else { - auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc()); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd, - *data_mem, *weight_mem, *bias_mem, *out_mem.second)); + FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); } + MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd()); CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Get()->Submit(); } +class MKLDNNFullyConnectBackward { + std::shared_ptr ipBwdData; + std::shared_ptr ipBwdWeight; + std::shared_ptr out_grad; + std::shared_ptr weight; + std::shared_ptr in_grad; + std::shared_ptr data; + std::shared_ptr in_grad_weight; + std::shared_ptr in_grad_bias; + + public: + mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd; + mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd; + + public: + MKLDNNFullyConnectBackward( + const FullyConnectedParam ¶m, const NDArray &data, + const NDArray &weight, const std::vector &in_grad, + const NDArray &out_grad, const std::vector &req, + const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd) + : ipBwdData_pd(GetIpBwdData(data, weight, out_grad, ipFwd_pd)), + ipBwdWeights_pd(GetIPBwdWeights( + data, weight, param.no_bias ? 
nullptr : &in_grad[fullc::kBias], + out_grad, ipFwd_pd)) {} + + void SetNewMemData(const mkldnn::memory &out_grad, + const mkldnn::memory &weight, + const mkldnn::memory &in_grad) { + if (this->out_grad == nullptr) + this->out_grad = std::shared_ptr(new mkldnn::memory( + ipBwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->out_grad->set_data_handle(out_grad.get_data_handle()); + + if (this->weight == nullptr) + this->weight = std::shared_ptr(new mkldnn::memory( + ipBwdData_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight->set_data_handle(weight.get_data_handle()); + + if (this->in_grad == nullptr) + this->in_grad = std::shared_ptr(new mkldnn::memory( + ipBwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle())); + else + this->in_grad->set_data_handle(in_grad.get_data_handle()); + + if (this->ipBwdData == nullptr) + this->ipBwdData = std::shared_ptr( + new mkldnn::inner_product_backward_data( + this->ipBwdData_pd, mkldnn::primitive::at(*this->out_grad), + mkldnn::primitive::at(*this->weight), *this->in_grad)); + } + + void SetNewWeightMem(const FullyConnectedParam ¶m, + const mkldnn::memory &data, + const mkldnn::memory &out_grad, + const mkldnn::memory &in_grad_weight, + const mkldnn::memory &in_grad_bias) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + ipBwdWeights_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + + if (this->out_grad == nullptr) + this->out_grad = std::shared_ptr(new mkldnn::memory( + ipBwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->out_grad->set_data_handle(out_grad.get_data_handle()); + + if (this->in_grad_weight == nullptr) + this->in_grad_weight = std::shared_ptr( + new mkldnn::memory(ipBwdWeights_pd.diff_weights_primitive_desc(), + in_grad_weight.get_data_handle())); + else + this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); + + if (!param.no_bias) { + if (this->in_grad_bias == nullptr) + this->in_grad_bias = std::shared_ptr( + new mkldnn::memory(ipBwdWeights_pd.diff_bias_primitive_desc(), + in_grad_bias.get_data_handle())); + else + this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle()); + + if (this->ipBwdWeight == nullptr) + this->ipBwdWeight = std::shared_ptr( + new mkldnn::inner_product_backward_weights( + this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight, *this->in_grad_bias)); + } else { + if (this->ipBwdWeight == nullptr) + this->ipBwdWeight = std::shared_ptr( + new mkldnn::inner_product_backward_weights( + this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight)); + } + } + + const mkldnn::inner_product_backward_data &GetBwdData() const { + return *ipBwdData; + } + + const mkldnn::inner_product_backward_weights &GetBwdWeights() const { + return *ipBwdWeight; + } +}; + +typedef ParamOpSign MKLDNNFullyconSignature; + +static inline MKLDNNFullyConnectBackward &GetFCBwd( + const FullyConnectedParam ¶m, const NDArray &data, + const NDArray &weight, const std::vector &in_grad, + const NDArray &out_grad, const std::vector &req, + const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map + bwdDatas; +#else + static MX_THREAD_LOCAL std::unordered_map + bwdDatas; +#endif + MKLDNNFullyconSignature key(param); + 
key.AddSign(data); + key.AddSign(weight); + key.AddSign(in_grad); + key.AddSign(out_grad); + + auto it = bwdDatas.find(key); + if (it == bwdDatas.end()) { + MKLDNNFullyConnectBackward bwdData(param, data, weight, in_grad, out_grad, + req, ipFwd_pd); + auto ins_ret = bwdDatas.insert( + std::pair( + key, bwdData)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -161,41 +399,34 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train); CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; + MKLDNNFullyConnectBackward &FCBwd = + GetFCBwd(param, data, weight, in_grad, out_grad, req, ipFwd_pd); if (req[fullc::kData]) { - mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData( - data, weight, out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdData_pd.diff_dst_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); + FCBwd.ipBwdData_pd.diff_dst_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(FCBwd.ipBwdData_pd.weights_primitive_desc()); auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], - ipBwdData_pd.diff_src_primitive_desc(), + FCBwd.ipBwdData_pd.diff_src_primitive_desc(), req[fullc::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( - ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); - CommitOutput(in_grad[fullc::kData], in_grad_mem); + FCBwd.SetNewMemData(*out_grad_mem, *weight_mem, *in_grad_mem.second); + MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdData()); } if (req[fullc::kWeight]) { - mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd - = GetIPBwdWeights(data, weight, param.no_bias ? 
nullptr : &in_grad[fullc::kBias], - out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], - ipBwdWeights_pd.diff_weights_primitive_desc(), - req[fullc::kWeight]); + FCBwd.ipBwdWeights_pd.diff_dst_primitive_desc()); + auto data_mem = + data.GetMKLDNNDataReorder(FCBwd.ipBwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad( + in_grad[fullc::kWeight], + FCBwd.ipBwdWeights_pd.diff_weights_primitive_desc(), + req[fullc::kWeight]); mkldnn_output_t in_grad_bias; - if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); - } else { - in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], - ipBwdWeights_pd.diff_bias_primitive_desc(), + in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], + FCBwd.ipBwdWeights_pd.diff_bias_primitive_desc(), req[fullc::kBias]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, - *in_grad_bias.second)); - } + FCBwd.SetNewWeightMem(param, *data_mem, *out_grad_mem, + *in_grad_weight.second, *in_grad_bias.second); + MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdWeights()); CommitOutput(in_grad[fullc::kWeight], in_grad_weight); CommitOutput(in_grad[fullc::kBias], in_grad_bias); } From e9f6a33cab0a6b9a429708cbe42a255ed8b70c54 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Mon, 9 Jul 2018 12:05:43 +0800 Subject: [PATCH 08/18] remove fc forward and fix indent problem --- src/operator/nn/mkldnn/mkldnn_convolution.cc | 6 +- .../nn/mkldnn/mkldnn_fully_connected.cc | 238 ++++++------------ 2 files changed, 78 insertions(+), 166 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index f6ad02cc0091..2e19d3219abb 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -309,7 +309,7 @@ class MKLDNNConvBackward { } void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight, - const mkldnn::memory &in_grad) { + const mkldnn::memory &in_grad) { if (this->out_grad == nullptr) this->out_grad = std::shared_ptr(new mkldnn::memory( bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); @@ -333,8 +333,8 @@ class MKLDNNConvBackward { } void SetWeightNewMem(const mkldnn::memory &data, - const mkldnn::memory &out_grad, - const mkldnn::memory &in_grad_weight) { + const mkldnn::memory &out_grad, + const mkldnn::memory &in_grad_weight) { if (this->data == nullptr) this->data = std::shared_ptr(new mkldnn::memory( bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 12e75612fe06..12bee8bdd038 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -82,150 +82,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei } } -class MKLDNNFullyConnectForward { - std::shared_ptr data; - std::shared_ptr weight; - std::shared_ptr out; - std::shared_ptr bias; - std::shared_ptr ipFwd; - - public: - mkldnn::inner_product_forward::primitive_desc ipFwd_pd; - - MKLDNNFullyConnectForward(const 
FullyConnectedParam ¶m, bool is_train, - const NDArray &data, const NDArray &weight, - const NDArray *bias, - const mkldnn::memory::desc &output) - : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {} - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - - if (this->weight == nullptr) - this->weight = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight->set_data_handle(weight.get_data_handle()); - - if (this->out == nullptr) - this->out = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out->set_data_handle(output.get_data_handle()); - - if (bias != nullptr) { - if (this->bias == nullptr) - this->bias = std::shared_ptr(new mkldnn::memory( - ipFwd_pd.bias_primitive_desc(), bias->get_data_handle())); - else - this->bias->set_data_handle(bias->get_data_handle()); - if (this->ipFwd == nullptr) - this->ipFwd = std::shared_ptr( - new mkldnn::inner_product_forward( - ipFwd_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->weight), - mkldnn::primitive::at(*this->bias), *this->out)); - } else if (this->ipFwd == nullptr) { - this->ipFwd = std::shared_ptr( - new mkldnn::inner_product_forward( - ipFwd_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->weight), *this->out)); - } - } - - const mkldnn::inner_product_forward &GetIpFwd() const { - return *ipFwd; - } -}; - -typedef ParamOpSign MKLDNNFullyconSignature; - -static inline MKLDNNFullyConnectForward &GetFCFwd( - const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight, - const NDArray *bias, const mkldnn::memory::desc &output, - const bool is_train) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fcFwds; -#else - static MX_THREAD_LOCAL std::unordered_map fcFwds; -#endif - const FullyConnectedParam& param = nnvm::get(attrs.parsed); - MKLDNNFullyconSignature key(param); - key.AddSign(data); - key.AddSign(weight); - key.AddSign(is_train); - - if (bias) - key.AddSign(*bias); - - auto it = fcFwds.find(key); - if (it == fcFwds.end()) { - MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias, - output); - auto ins_ret = fcFwds.insert( - std::pair(key, fcFwd)); - CHECK(ins_ret.second); - it = ins_ret.first; - } - return it->second; -} - -void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { - TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); - const FullyConnectedParam& param = nnvm::get(attrs.parsed); - const TShape& ishape = in_data[fullc::kData].shape(); - const TShape& oshape = out_data[fullc::kOut].shape(); - NDArray weight = in_data[fullc::kWeight]; - NDArray data = in_data[fullc::kData]; - // If the input data is a view of an MKLDNN array, we should create a new - // NDArray with reordered data. 
- if (data.IsMKLDNNData() && data.IsView()) - data = in_data[fullc::kData].Reorder2Default(); - - auto out_md = GetMemDesc(out_data[fullc::kOut]); - if (data.shape().ndim() != 2 && !param.flatten) { - data = data.MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1), - ishape[ishape.ndim()-1])); - mkldnn::memory::dims out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), - static_cast(oshape[ishape.ndim()-1])}; - out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), - mkldnn::memory::format::any); - } else if (data.shape().ndim() != 2) { - data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); - mkldnn::memory::dims out_dims{static_cast(oshape[0]), - static_cast(oshape.ProdShape(1, oshape.ndim()))}; - out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), - mkldnn::memory::format::any); - } - MKLDNNFullyConnectForward &FCFwd = - GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias], - out_md, ctx.is_train); - auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc()); - auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], - FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); - if (!param.no_bias) { - auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder( - FCFwd.ipFwd_pd.bias_primitive_desc()); - FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - } else { - FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); - } - MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd()); - CommitOutput(out_data[fullc::kOut], out_mem); - MKLDNNStream::Get()->Submit(); -} - class MKLDNNFullyConnectBackward { std::shared_ptr ipBwdData; std::shared_ptr ipBwdWeight; @@ -369,6 +225,55 @@ static inline MKLDNNFullyConnectBackward &GetFCBwd( return it->second; } +void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]); + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + const TShape& ishape = in_data[fullc::kData].shape(); + const TShape& oshape = out_data[fullc::kOut].shape(); + NDArray weight = in_data[fullc::kWeight]; + NDArray data = in_data[fullc::kData]; + // If the input data is a view of an MKLDNN array, we should create a new + // NDArray with reordered data. 
+ if (data.IsMKLDNNData() && data.IsView()) + data = in_data[fullc::kData].Reorder2Default(); + + auto out_md = GetMemDesc(out_data[fullc::kOut]); + if (data.shape().ndim() != 2 && !param.flatten) { + data = data.MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1), + ishape[ishape.ndim()-1])); + mkldnn::memory::dims out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), + static_cast(oshape[ishape.ndim()-1])}; + out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), + mkldnn::memory::format::any); + } else if (data.shape().ndim() != 2) { + data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); + mkldnn::memory::dims out_dims{static_cast(oshape[0]), + static_cast(oshape.ProdShape(1, oshape.ndim()))}; + out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), + mkldnn::memory::format::any); + } + + mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, + param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train); + auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc()); + auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], + ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); + if (param.no_bias) { + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward( + ipFwd_pd, *data_mem, *weight_mem, *out_mem.second)); + } else { + auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc()); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd, + *data_mem, *weight_mem, *bias_mem, *out_mem.second)); + } + CommitOutput(out_data[fullc::kOut], out_mem); + MKLDNNStream::Get()->Submit(); +} + void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -399,34 +304,41 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train); CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; - MKLDNNFullyConnectBackward &FCBwd = - GetFCBwd(param, data, weight, in_grad, out_grad, req, ipFwd_pd); if (req[fullc::kData]) { + mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData( + data, weight, out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - FCBwd.ipBwdData_pd.diff_dst_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(FCBwd.ipBwdData_pd.weights_primitive_desc()); + ipBwdData_pd.diff_dst_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], - FCBwd.ipBwdData_pd.diff_src_primitive_desc(), + ipBwdData_pd.diff_src_primitive_desc(), req[fullc::kData]); - FCBwd.SetNewMemData(*out_grad_mem, *weight_mem, *in_grad_mem.second); - MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdData()); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( + ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[fullc::kData], in_grad_mem); } if (req[fullc::kWeight]) { + mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd + = GetIPBwdWeights(data, weight, param.no_bias ? 
nullptr : &in_grad[fullc::kBias], + out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - FCBwd.ipBwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = - data.GetMKLDNNDataReorder(FCBwd.ipBwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad( - in_grad[fullc::kWeight], - FCBwd.ipBwdWeights_pd.diff_weights_primitive_desc(), - req[fullc::kWeight]); + ipBwdWeights_pd.diff_dst_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], + ipBwdWeights_pd.diff_weights_primitive_desc(), + req[fullc::kWeight]); mkldnn_output_t in_grad_bias; - in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], - FCBwd.ipBwdWeights_pd.diff_bias_primitive_desc(), + if (param.no_bias) { + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); + } else { + in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], + ipBwdWeights_pd.diff_bias_primitive_desc(), req[fullc::kBias]); - FCBwd.SetNewWeightMem(param, *data_mem, *out_grad_mem, - *in_grad_weight.second, *in_grad_bias.second); - MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdWeights()); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, + *in_grad_bias.second)); + } CommitOutput(in_grad[fullc::kWeight], in_grad_weight); CommitOutput(in_grad[fullc::kBias], in_grad_bias); } From 2f3f4362e4813cc3c6363da0388c66147ec77c0e Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Mon, 9 Jul 2018 12:27:47 +0800 Subject: [PATCH 09/18] remove fc forward and fix convolution indent problem --- .../nn/mkldnn/mkldnn_fully_connected.cc | 45 ++++++++----------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 12bee8bdd038..57000b8644a5 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -304,41 +304,34 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, param.no_bias ? 
nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train); CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; + MKLDNNFullyConnectBackward &FCBwd = + GetFCBwd(param, data, weight, in_grad, out_grad, req, ipFwd_pd); if (req[fullc::kData]) { - mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData( - data, weight, out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdData_pd.diff_dst_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); + FCBwd.ipBwdData_pd.diff_dst_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(FCBwd.ipBwdData_pd.weights_primitive_desc()); auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], - ipBwdData_pd.diff_src_primitive_desc(), + FCBwd.ipBwdData_pd.diff_src_primitive_desc(), req[fullc::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( - ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); - CommitOutput(in_grad[fullc::kData], in_grad_mem); + FCBwd.SetNewMemData(*out_grad_mem, *weight_mem, *in_grad_mem.second); + MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdData()); } if (req[fullc::kWeight]) { - mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd - = GetIPBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], - out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], - ipBwdWeights_pd.diff_weights_primitive_desc(), - req[fullc::kWeight]); + FCBwd.ipBwdWeights_pd.diff_dst_primitive_desc()); + auto data_mem = + data.GetMKLDNNDataReorder(FCBwd.ipBwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad( + in_grad[fullc::kWeight], + FCBwd.ipBwdWeights_pd.diff_weights_primitive_desc(), + req[fullc::kWeight]); mkldnn_output_t in_grad_bias; - if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); - } else { - in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], - ipBwdWeights_pd.diff_bias_primitive_desc(), + in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], + FCBwd.ipBwdWeights_pd.diff_bias_primitive_desc(), req[fullc::kBias]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, - *in_grad_bias.second)); - } + FCBwd.SetNewWeightMem(param, *data_mem, *out_grad_mem, + *in_grad_weight.second, *in_grad_bias.second); + MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdWeights()); CommitOutput(in_grad[fullc::kWeight], in_grad_weight); CommitOutput(in_grad[fullc::kBias], in_grad_bias); } From 315abb80dc75657c4e298e217d737a37f61a97f7 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 9 Jul 2018 13:15:24 +0800 Subject: [PATCH 10/18] Change log level to FATAL for unreachable code in mkldnn_act.cc --- src/operator/nn/mkldnn/mkldnn_act.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index e7fe79d0e5ff..4b8a4e7ebbe0 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( alg, data_md, 
alpha); return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); }); - LOG(INFO) << "Unsupported data type for MKLDNN activation"; + LOG(FATAL) << "Unsupported data type for MKLDNN activation"; mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc( mkldnn::prop_kind::forward_training, alg, data_md, 0.0); return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); From 21b1a68da2f0dcac98d7fa7dcf43f0f1f3dae3b6 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Wed, 11 Jul 2018 09:33:02 +0800 Subject: [PATCH 11/18] remove fc forward and fix convolution indent problem --- src/operator/nn/fully_connected.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 48d479ccf60c..de4b66169e88 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -211,11 +211,11 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), out_expected); DispatchMode wanted_mode; -#if 0 +#if 1 // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. // It seems there is a bug. if (dev_mask == mshadow::cpu::kDevMask) - *dispatch_mode = DispatchMode::kFComputeEx; + wanted_mode = DispatchMode::kFComputeEx; else #endif wanted_mode = DispatchMode::kFCompute; From dea6f91082f3958717be474033677d9893d46464 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Wed, 11 Jul 2018 09:48:10 +0800 Subject: [PATCH 12/18] remove useless hint in fc --- src/operator/nn/fully_connected.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index de4b66169e88..0e51ec7dec7b 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -212,8 +212,6 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, DispatchMode wanted_mode; #if 1 - // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. - // It seems there is a bug. if (dev_mask == mshadow::cpu::kDevMask) wanted_mode = DispatchMode::kFComputeEx; else From f160c113d46310afc02308dfa96cee06277a89b2 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Fri, 13 Jul 2018 15:54:31 +0800 Subject: [PATCH 13/18] Merge branch 'master' into backward_op_cache --- src/operator/nn/fully_connected.cc | 24 +-- .../nn/mkldnn/mkldnn_fully_connected.cc | 188 +++--------------- 2 files changed, 37 insertions(+), 175 deletions(-) diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index fd6fd4c79707..ed950c016dd3 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -209,19 +209,17 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), out_expected); - bool dispatched = false; - if (common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) { - dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage, - dispatch_mode, DispatchMode::kFComputeEx); - } - if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) { - dispatched = dispatch_fallback(out_attrs, dispatch_mode); - } - if (!dispatched) { - dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage, - dispatch_mode, DispatchMode::kFCompute); - } - return dispatched; + DispatchMode wanted_mode; +#if 0 + // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. + // It seems there is a bug. 
+ if (dev_mask == mshadow::cpu::kDevMask) + *dispatch_mode = DispatchMode::kFComputeEx; + else +#endif + wanted_mode = DispatchMode::kFCompute; + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, wanted_mode); } DMLC_REGISTER_PARAMETER(FullyConnectedParam); diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 57000b8644a5..f86f8dbefa2b 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -82,149 +82,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei } } -class MKLDNNFullyConnectBackward { - std::shared_ptr ipBwdData; - std::shared_ptr ipBwdWeight; - std::shared_ptr out_grad; - std::shared_ptr weight; - std::shared_ptr in_grad; - std::shared_ptr data; - std::shared_ptr in_grad_weight; - std::shared_ptr in_grad_bias; - - public: - mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd; - mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd; - - public: - MKLDNNFullyConnectBackward( - const FullyConnectedParam ¶m, const NDArray &data, - const NDArray &weight, const std::vector &in_grad, - const NDArray &out_grad, const std::vector &req, - const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd) - : ipBwdData_pd(GetIpBwdData(data, weight, out_grad, ipFwd_pd)), - ipBwdWeights_pd(GetIPBwdWeights( - data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], - out_grad, ipFwd_pd)) {} - - void SetNewMemData(const mkldnn::memory &out_grad, - const mkldnn::memory &weight, - const mkldnn::memory &in_grad) { - if (this->out_grad == nullptr) - this->out_grad = std::shared_ptr(new mkldnn::memory( - ipBwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->out_grad->set_data_handle(out_grad.get_data_handle()); - - if (this->weight == nullptr) - this->weight = std::shared_ptr(new mkldnn::memory( - ipBwdData_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight->set_data_handle(weight.get_data_handle()); - - if (this->in_grad == nullptr) - this->in_grad = std::shared_ptr(new mkldnn::memory( - ipBwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle())); - else - this->in_grad->set_data_handle(in_grad.get_data_handle()); - - if (this->ipBwdData == nullptr) - this->ipBwdData = std::shared_ptr( - new mkldnn::inner_product_backward_data( - this->ipBwdData_pd, mkldnn::primitive::at(*this->out_grad), - mkldnn::primitive::at(*this->weight), *this->in_grad)); - } - - void SetNewWeightMem(const FullyConnectedParam ¶m, - const mkldnn::memory &data, - const mkldnn::memory &out_grad, - const mkldnn::memory &in_grad_weight, - const mkldnn::memory &in_grad_bias) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - ipBwdWeights_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - - if (this->out_grad == nullptr) - this->out_grad = std::shared_ptr(new mkldnn::memory( - ipBwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->out_grad->set_data_handle(out_grad.get_data_handle()); - - if (this->in_grad_weight == nullptr) - this->in_grad_weight = std::shared_ptr( - new mkldnn::memory(ipBwdWeights_pd.diff_weights_primitive_desc(), - in_grad_weight.get_data_handle())); - else - this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); - - if (!param.no_bias) { - if (this->in_grad_bias == nullptr) - 
this->in_grad_bias = std::shared_ptr( - new mkldnn::memory(ipBwdWeights_pd.diff_bias_primitive_desc(), - in_grad_bias.get_data_handle())); - else - this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle()); - - if (this->ipBwdWeight == nullptr) - this->ipBwdWeight = std::shared_ptr( - new mkldnn::inner_product_backward_weights( - this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight, *this->in_grad_bias)); - } else { - if (this->ipBwdWeight == nullptr) - this->ipBwdWeight = std::shared_ptr( - new mkldnn::inner_product_backward_weights( - this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight)); - } - } - - const mkldnn::inner_product_backward_data &GetBwdData() const { - return *ipBwdData; - } - - const mkldnn::inner_product_backward_weights &GetBwdWeights() const { - return *ipBwdWeight; - } -}; - -typedef ParamOpSign MKLDNNFullyconSignature; - -static inline MKLDNNFullyConnectBackward &GetFCBwd( - const FullyConnectedParam ¶m, const NDArray &data, - const NDArray &weight, const std::vector &in_grad, - const NDArray &out_grad, const std::vector &req, - const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map - bwdDatas; -#else - static MX_THREAD_LOCAL std::unordered_map - bwdDatas; -#endif - MKLDNNFullyconSignature key(param); - key.AddSign(data); - key.AddSign(weight); - key.AddSign(in_grad); - key.AddSign(out_grad); - - auto it = bwdDatas.find(key); - if (it == bwdDatas.end()) { - MKLDNNFullyConnectBackward bwdData(param, data, weight, in_grad, out_grad, - req, ipFwd_pd); - auto ins_ret = bwdDatas.insert( - std::pair( - key, bwdData)); - CHECK(ins_ret.second); - it = ins_ret.first; - } - return it->second; -} - void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, const std::vector &req, @@ -304,34 +161,41 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train); CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; - MKLDNNFullyConnectBackward &FCBwd = - GetFCBwd(param, data, weight, in_grad, out_grad, req, ipFwd_pd); if (req[fullc::kData]) { + mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData( + data, weight, out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - FCBwd.ipBwdData_pd.diff_dst_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(FCBwd.ipBwdData_pd.weights_primitive_desc()); + ipBwdData_pd.diff_dst_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], - FCBwd.ipBwdData_pd.diff_src_primitive_desc(), + ipBwdData_pd.diff_src_primitive_desc(), req[fullc::kData]); - FCBwd.SetNewMemData(*out_grad_mem, *weight_mem, *in_grad_mem.second); - MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdData()); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( + ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[fullc::kData], in_grad_mem); } if (req[fullc::kWeight]) { + mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd + = GetIPBwdWeights(data, weight, param.no_bias ? 
nullptr : &in_grad[fullc::kBias], + out_grad, ipFwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - FCBwd.ipBwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = - data.GetMKLDNNDataReorder(FCBwd.ipBwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad( - in_grad[fullc::kWeight], - FCBwd.ipBwdWeights_pd.diff_weights_primitive_desc(), - req[fullc::kWeight]); + ipBwdWeights_pd.diff_dst_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], + ipBwdWeights_pd.diff_weights_primitive_desc(), + req[fullc::kWeight]); mkldnn_output_t in_grad_bias; - in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], - FCBwd.ipBwdWeights_pd.diff_bias_primitive_desc(), + if (param.no_bias) { + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); + } else { + in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], + ipBwdWeights_pd.diff_bias_primitive_desc(), req[fullc::kBias]); - FCBwd.SetNewWeightMem(param, *data_mem, *out_grad_mem, - *in_grad_weight.second, *in_grad_bias.second); - MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdWeights()); + MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, + *in_grad_bias.second)); + } CommitOutput(in_grad[fullc::kWeight], in_grad_weight); CommitOutput(in_grad[fullc::kBias], in_grad_bias); } From 913a1437d5895acf2e9077e36cb34113c779a91d Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 16 Jul 2018 12:21:05 +0800 Subject: [PATCH 14/18] Empty commit to retrigger the CI. --- src/operator/nn/fully_connected.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 390f0cfdc573..dc392641620b 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -208,7 +208,6 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, uint32_t out_expected = param.no_bias ? 2 : 3; CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), out_expected); - // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. // It seems there is a bug. bool dispatched = false; From 75039e1e67dad6980b6d773dc7cac3ae6ee08012 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 16 Jul 2018 15:33:45 +0800 Subject: [PATCH 15/18] Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc --- src/operator/nn/mkldnn/mkldnn_act.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index f944d66f4494..e656db737bd4 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -194,7 +194,7 @@ static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( fw_pdesc); return bw_pdesc; }); - LOG(INFO) << "Unsupported data type for MKLDNN activation"; + LOG(FATAL) << "Unsupported data type for MKLDNN activation"; mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, alg, data_md, 0.0); mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); From d92915b229da85018b96b9a2bf25c84af2e228f5 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 25 Jul 2018 12:18:06 +0800 Subject: [PATCH 16/18] Fix build issue after code merge. 
--- src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 27ffbdf9774a..66679613d3ae 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -94,9 +94,9 @@ class MKLDNNPoolingBwd { bool with_ws); ~MKLDNNPoolingBwd() {} - void SetDataHandle(const mxnet::NDArray *workspace, - const mxnet::NDArray &out_grad, - const mkldnn::memory *diff_src_mem); + void SetNewMem(const mxnet::NDArray *workspace, + const mxnet::NDArray &out_grad, + const mkldnn::memory *diff_src_mem); const mkldnn::pooling_backward &GetBwd(); const mkldnn::pooling_backward::primitive_desc &GetPd(); }; From ae4a749b8c4cba8e4e62a4b8bad170f0f672aadf Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 27 Aug 2018 12:56:33 +0800 Subject: [PATCH 17/18] Fix lint after merge --- src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index e00d02247d57..bc386bedde1a 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -256,7 +256,7 @@ class MKLDNNLRNBwd { } } - void Execute(const NDArray &in_grad, mkldnn_output_t &diff_src_mem_) { + void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) { MKLDNNStream::Get()->RegisterPrim(*(this->bwd)); CommitOutput(in_grad, diff_src_mem_); MKLDNNStream::Get()->Submit(); From c34c6034343923c8ceb32494e75688c266fff678 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 31 Aug 2018 16:17:22 +0800 Subject: [PATCH 18/18] Fix mkldnn act. --- src/operator/nn/mkldnn/mkldnn_act.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index e656db737bd4..c914b38b542a 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -287,6 +287,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx in_buffer = in_data.Reorder2Default(); const ActivationParam& param = nnvm::get(attrs.parsed); + TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); auto diff_dst_memory = out_buffer.GetMKLDNNData(); auto input_mem = in_buffer.GetMKLDNNData(); // We need to make sure the two inputs to eltwise_backward has the same memory @@ -296,7 +297,6 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx MKLDNNActBackward &bwd = GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); - TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); mkldnn_output_t diff_src_memory = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
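
The patches in this series all apply one pattern: build the expensive MKL-DNN backward primitive once, cache it in a thread-local map keyed by the operator parameters plus the signatures of the arrays involved (ParamOpSign), and on later calls only rebind the data handles (SetDataNewMem / SetWeightNewMem / SetNewMemData) before re-registering the primitive on the stream. Below is a minimal, self-contained C++ sketch of that caching pattern, for illustration only; every name in it (ConvParam, OpSignature, BackwardPrimitive, GetCachedBwd) is a placeholder invented for the sketch and is not an MXNet or MKL-DNN symbol, and the real keys/hashes are the ParamOpSign and std::hash specializations shown in the diffs above.

#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

// Stand-in for an operator parameter struct (cf. FullyConnectedParam above).
struct ConvParam {
  int num_filter;
  int kernel;
  bool no_bias;
  bool operator==(const ConvParam &o) const {
    return num_filter == o.num_filter && kernel == o.kernel && no_bias == o.no_bias;
  }
};

// Stand-in for ParamOpSign: the parameters plus shape info of inputs/outputs.
struct OpSignature {
  ConvParam param;
  std::vector<int> shapes;
  bool operator==(const OpSignature &o) const {
    return param == o.param && shapes == o.shapes;
  }
};

// Hash functor for the cache key (plays the role of the std::hash
// specialization added for FullyConnectedParam in the patch).
struct OpSignatureHash {
  std::size_t operator()(const OpSignature &s) const {
    std::size_t h = std::hash<int>()(s.param.num_filter);
    auto mix = [&h](std::size_t v) { h = h * 31 + v; };
    mix(std::hash<int>()(s.param.kernel));
    mix(std::hash<bool>()(s.param.no_bias));
    for (int d : s.shapes) mix(std::hash<int>()(d));
    return h;
  }
};

// Stand-in for an expensive-to-build backward primitive. The constructor does
// the costly setup once; SetNewMem() only rebinds data pointers per call,
// mirroring the SetDataNewMem / SetWeightNewMem methods in the diffs.
class BackwardPrimitive {
 public:
  explicit BackwardPrimitive(const OpSignature &sig) : sig_(sig) {
    std::cout << "building backward primitive (expensive, once per key)\n";
  }
  void SetNewMem(const void *out_grad, void *in_grad) {
    out_grad_ = out_grad;
    in_grad_ = in_grad;
  }
  void Execute() const { /* run on the currently bound buffers */ }

 private:
  OpSignature sig_;
  const void *out_grad_ = nullptr;
  void *in_grad_ = nullptr;
};

// Thread-local cache, mirroring the DMLC_CXX11_THREAD_LOCAL / MX_THREAD_LOCAL
// maps in the patches: one map per thread, keyed by signature, filled lazily.
BackwardPrimitive &GetCachedBwd(const OpSignature &sig) {
  static thread_local std::unordered_map<OpSignature, BackwardPrimitive,
                                         OpSignatureHash> cache;
  auto it = cache.find(sig);
  if (it == cache.end())
    it = cache.emplace(sig, BackwardPrimitive(sig)).first;
  return it->second;
}

int main() {
  OpSignature sig{{64, 3, false}, {1, 3, 224, 224}};
  float grad_out[4] = {0.f}, grad_in[4] = {0.f};
  for (int iter = 0; iter < 3; ++iter) {
    BackwardPrimitive &bwd = GetCachedBwd(sig);  // constructed on iter 0 only
    bwd.SetNewMem(grad_out, grad_in);
    bwd.Execute();
  }
  return 0;
}

Keeping the cache thread_local, as the patches do, means each worker thread owns its primitives, so no locking is needed and the data handles rebound through SetNewMem stay private to that thread between registration and submission.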