diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index eb881d29abd1..17a485c53f65 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -213,10 +213,9 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, uint32_t out_expected = param.no_bias ? 2 : 3; CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), out_expected); - - bool dispatched = false; // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. // It seems there is a bug. + bool dispatched = false; if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) { dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 587cf930920e..410840331912 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -116,6 +116,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs, MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs); MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]); MKLDNN_OPCHECK_RUN(LRNCompute, attrs, ctx, inputs, req, outputs); + // Copy outputs[1] from opcheck reference as backward check needs it. + MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector{1}); return; } FallBackCompute(LRNCompute, attrs, ctx, inputs, req, outputs); diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index 744fed2c299f..c914b38b542a 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( alg, data_md, alpha); return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); }); - LOG(INFO) << "Unsupported data type for MKLDNN activation"; + LOG(FATAL) << "Unsupported data type for MKLDNN activation"; mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc( mkldnn::prop_kind::forward_training, alg, data_md, 0.0); return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); @@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, stream->Submit(); } +static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( + const ActivationParam ¶m, const mkldnn::memory &input_mem, + const mkldnn::memory &diff_dst_memory, int dtype) { + mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc(); + auto cpu_engine = data_mpd.get_engine(); + auto alg = GetMKLDNNActAlgo(param); + + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType alpha = 0; + mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, + alg, data_md, alpha); + mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha); + mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, + fw_pdesc); + return bw_pdesc; + }); + LOG(FATAL) << "Unsupported data type for MKLDNN activation"; + mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, + alg, data_md, 0.0); + mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0); + mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, + fw_pdesc); + return bw_pdesc; +} + +class MKLDNNActBackward { + std::shared_ptr bwd; 
+ std::shared_ptr data; + std::shared_ptr diff_dst_memory; + std::shared_ptr diff_src_memory; + + public: + const mkldnn::eltwise_backward::primitive_desc pd; + + explicit MKLDNNActBackward(const ActivationParam ¶m, const NDArray &data, + const mkldnn::memory &mem, + const mkldnn::memory &diff_dst_memory) + : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {} + + void SetNewMem(const mkldnn::memory &data, + const mkldnn::memory &diff_dst_memory, + const mkldnn::memory &diff_src_memory) { + if (this->bwd != nullptr) { + this->data->set_data_handle(data.get_data_handle()); + this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle()); + this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle()); + } else { + this->data = std::shared_ptr(new mkldnn::memory( + data.get_primitive_desc(), data.get_data_handle())); + this->diff_dst_memory = std::shared_ptr( + new mkldnn::memory(diff_dst_memory.get_primitive_desc(), + diff_dst_memory.get_data_handle())); + this->diff_src_memory = std::shared_ptr( + new mkldnn::memory(diff_src_memory.get_primitive_desc(), + diff_src_memory.get_data_handle())); + this->bwd = std::shared_ptr( + new mkldnn::eltwise_backward( + this->pd, mkldnn::primitive::at(*this->data), + *this->diff_dst_memory, *this->diff_src_memory)); + } + } + + const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; } +}; + +static inline MKLDNNActBackward &GetActBackward(const ActivationParam ¶m, + const OpContext &ctx, + const NDArray &in_data, + const NDArray &out_grad, + const mkldnn::memory &in_mem) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + MKLDNNActSignature key(param); + key.AddSign(in_data); + key.AddSign(out_grad); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData()); + auto ins_ret = + bwds.insert(std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + // For backward relu activation, it's okay to pass "out_data" as "in_data" to this // function, since the computation only involes non-zeros. void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -200,30 +294,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx // descriptor. Otherwise, the perf will suffer. 
if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - mkldnn::memory::desc data_md = data_mpd.desc(); - mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc(); - auto cpu_engine = data_mpd.get_engine(); - + MKLDNNActBackward &bwd = + GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); - auto alg = GetMKLDNNActAlgo(param); - mkldnn_output_t diff_src_memory; - - MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, { - DType alpha = 0; - mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, - alg, data_md, alpha); - mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); - mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha); - mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, - fw_pdesc); - - diff_src_memory = CreateMKLDNNMem(in_grad, - bw_pdesc.diff_src_primitive_desc(), req); - stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem, - *diff_dst_memory, - *diff_src_memory.second)); - }); + mkldnn_output_t diff_src_memory = + CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); + bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second); + stream->RegisterPrim(bwd.GetBwd()); CommitOutput(in_grad, diff_src_memory); stream->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 273afcd32dc7..4d0b8d062afa 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -497,6 +497,9 @@ class OpCheck { const std::vector &inputs_, const std::vector &req, const std::vector &outputs_); + + void CopyResult(const std::vector &outputs_, + const std::vector& indice); }; bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, @@ -513,6 +516,8 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, #define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs) \ if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs); +#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \ + if (debug) check.CopyResult(outputs, indice); } // namespace mxnet #endif diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index f3facd966aa7..029f23bd8f5e 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -525,6 +525,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs, } } +void OpCheck::CopyResult(const std::vector &outputs_, + const std::vector &indice) { + CHECK(!MKLDNNStream::Get()->HasOps()); + auto non_const_outputs_ = const_cast &>(outputs_); + for (auto i = indice.begin(); i != indice.end(); ++i) { + auto mem = outputs[*i].GetMKLDNNData(); + non_const_outputs_[*i].CopyFrom(*mem); + } + MKLDNNStream::Get()->Submit(); +} + bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, const int dev_mask, bool support_mkldnn, diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 9046836e8e75..496ff99f4ee9 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, } } +class MKLDNNBNBackward { + std::shared_ptr bwd; + std::shared_ptr data_m; + std::shared_ptr 
diff_m; + std::shared_ptr gradi_m; + std::shared_ptr mean_m; + std::shared_ptr var_m; + const std::shared_ptr weight_m; + const std::shared_ptr gradw_m; + + public: + const t_bn_b_pdesc pd; + + explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd) + : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())), + gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())), + pd(_pd) {} + + const mkldnn::memory &GetWeight() const { return *weight_m; } + + const mkldnn::memory &GetGradw() const { return *gradw_m; } + + void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff, + const NDArray &mean, const mkldnn::memory &var, + const mkldnn::memory &gradi) { + auto mean_ptr = mean.data().dptr_; + if (bwd == nullptr) { + data_m.reset(new mkldnn::memory(data.get_primitive_desc(), + data.get_data_handle())); + diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(), + diff.get_data_handle())); + gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(), + gradi.get_data_handle())); + mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr)); + var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(), + var.get_data_handle())); + bwd.reset(new mkldnn::batch_normalization_backward( + pd, *data_m, mkldnn::primitive::at(*mean_m), + mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m, + *gradw_m)); + } else { + data_m->set_data_handle(data.get_data_handle()); + diff_m->set_data_handle(diff.get_data_handle()); + gradi_m->set_data_handle(gradi.get_data_handle()); + mean_m->set_data_handle(mean_ptr); + var_m->set_data_handle(var.get_data_handle()); + } + } + + const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; } +}; + +template +static MKLDNNBNBackward &GetBNBackward( + const BatchNormParam ¶m, const OpContext &ctx, const NDArray &in_data, + const mkldnn::memory &in_mem, const NDArray &diff_data, + const mkldnn::memory &diff_mem, unsigned flags) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + MKLDNNBNSignature key(param); + key.AddSign(in_data); + key.AddSign(diff_data); + + auto it = bwds.find(key); + if (it == bwds.end()) { + auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags); + MKLDNNBNBackward bwd(bwd_pd); + auto ins_ret = + bwds.insert(std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + template void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &out_grad, @@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc()); else if (diff.IsDefaultData()) diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc()); - auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags); + auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc()); if (flags & use_scale_shift) { const NDArray &gamma = in_data[batchnorm::kGamma]; const NDArray &beta = in_data[batchnorm::kBeta]; - // TODO(tao): how to reuse this memory? 
- std::shared_ptr weight_mem( - new mkldnn::memory(bwd_pd.weights_primitive_desc())); - - DType* weight_buf = reinterpret_cast(weight_mem->get_data_handle()); + DType *weight_buf = reinterpret_cast(bwd.GetWeight().get_data_handle()); nnvm::dim_t channels_ = data.shape()[1]; for (int i = 0; i < channels_; i++) { if (!param.fix_gamma) @@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, weight_buf[channels_ + i] = (beta.data().dptr())[i]; // bias } - std::shared_ptr gradw_mem( - new mkldnn::memory(bwd_pd.diff_weights_primitive_desc())); // training but no input mean and variance if (ctx.is_train && !param.use_global_stats) { DType* moving_mean_ptr = reinterpret_cast(moving_mean.data().dptr()); DType* moving_var_ptr = reinterpret_cast(moving_var.data().dptr()); DType* out_mean_ptr = reinterpret_cast(out_mean.data().dptr()); DType* out_var_ptr = reinterpret_cast(out_var.data().dptr()); - mkldnn::memory var_mem(bwd_pd.variance_primitive_desc()); + mkldnn::memory var_mem(bwd.pd.variance_primitive_desc()); DType *tmp_var_ptr = reinterpret_cast(var_mem.get_data_handle()); DType minus_mom = (1.0f - param.momentum); @@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom; } - - std::shared_ptr out_mean_mem( - new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr)); - std::shared_ptr out_var_mem( - new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr)); - - auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd, - *data_mem, - mkldnn::primitive::at(*out_mean_mem), - mkldnn::primitive::at(var_mem), - *diff_mem, - *weight_mem, - *gradi_mem, - *gradw_mem); - - MKLDNNStream::Get()->RegisterPrim(bn_bwd); + bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem); + MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); MKLDNNStream::Get()->Submit(); } else { - std::shared_ptr imean_mem( - new mkldnn::memory(bwd_pd.mean_primitive_desc(), - moving_mean.data().dptr())); - std::shared_ptr ivar_mem( - new mkldnn::memory(bwd_pd.variance_primitive_desc(), - moving_var.data().dptr())); - auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd, - *data_mem, - mkldnn::primitive::at(*imean_mem), - mkldnn::primitive::at(*ivar_mem), - *diff_mem, - *weight_mem, - *gradi_mem, - *gradw_mem); - - MKLDNNStream::Get()->RegisterPrim(bn_bwd); + bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean, + *moving_var.GetMKLDNNData(), *gradi_mem); + MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); MKLDNNStream::Get()->Submit(); } // copy data from gradw_mem to in_grad[1] and in_grad[2] - DType* gw_buf = reinterpret_cast(gradw_mem->get_data_handle()); + DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); for (int i = 0; i < channels_; i++) { if (!param.fix_gamma) (in_grad[1].data().dptr())[i] = gw_buf[i]; diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index cf04ea8da3d7..2e19d3219abb 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx MKLDNNStream::Get()->Submit(); } +class MKLDNNConvBackward { + std::shared_ptr bwd_data; + std::shared_ptr bwd_weight; + // conv::kData + std::shared_ptr out_grad; + std::shared_ptr in_grad; + std::shared_ptr weight; + // conv::kWeight + std::shared_ptr data; + std::shared_ptr 
output; + std::shared_ptr in_grad_weight; + std::shared_ptr in_grad_bias; + + public: + mkldnn::convolution_backward_data::primitive_desc bwdData_pd; + mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd; + + MKLDNNConvBackward( + const ConvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray *bias, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &fwd_pd): + bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)), + bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) { + } + + void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight, + const mkldnn::memory &in_grad) { + if (this->out_grad == nullptr) + this->out_grad = std::shared_ptr(new mkldnn::memory( + bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->out_grad->set_data_handle(out_grad.get_data_handle()); + if (this->in_grad == nullptr) + this->in_grad = std::shared_ptr(new mkldnn::memory( + bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle())); + else + this->in_grad->set_data_handle(in_grad.get_data_handle()); + if (this->weight == nullptr) + this->weight = std::shared_ptr(new mkldnn::memory( + bwdData_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight->set_data_handle(weight.get_data_handle()); + if (this->bwd_data == nullptr) + this->bwd_data = std::shared_ptr( + new mkldnn::convolution_backward_data( + this->bwdData_pd, mkldnn::primitive::at(*this->out_grad), + mkldnn::primitive::at(*this->weight), *this->in_grad)); + } + +void SetWeightNewMem(const mkldnn::memory &data, + const mkldnn::memory &out_grad, + const mkldnn::memory &in_grad_weight) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + if (this->output == nullptr) + this->output = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->output->set_data_handle(out_grad.get_data_handle()); + if (this->in_grad_weight == nullptr) + this->in_grad_weight = std::shared_ptr( + new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), + in_grad_weight.get_data_handle())); + else + this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); + + if (this->bwd_weight == nullptr) + this->bwd_weight = std::shared_ptr( + new mkldnn::convolution_backward_weights( + this->bwdWeights_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->output), *this->in_grad_weight)); + } + + void SetWeightNewMem(const mkldnn::memory &data, + const mkldnn::memory &out_grad, + const mkldnn::memory &in_grad_weight, + const mkldnn::memory &in_grad_bias) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + if (this->output == nullptr) + this->output = std::shared_ptr(new mkldnn::memory( + bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); + else + this->output->set_data_handle(out_grad.get_data_handle()); + if (this->in_grad_weight == nullptr) + this->in_grad_weight = std::shared_ptr( + new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), + in_grad_weight.get_data_handle())); + else + this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); + + if 
(this->in_grad_bias == nullptr) + this->in_grad_bias = std::shared_ptr( + new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(), + in_grad_bias.get_data_handle())); + else + this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle()); + if (this->bwd_weight == nullptr) + this->bwd_weight = std::shared_ptr( + new mkldnn::convolution_backward_weights( + this->bwdWeights_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->output), *this->in_grad_weight, + *this->in_grad_bias)); + } + + const mkldnn::convolution_backward_data &GetBwdData() const { + return *bwd_data; + } + + const mkldnn::convolution_backward_weights &GetBwdWeights() const { + return *bwd_weight; + } +}; + +static inline MKLDNNConvBackward &GetConvBwd( + const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &fwd_pd) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + const ConvolutionParam& param = nnvm::get(attrs.parsed); + MKLDNNConvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. + key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + if (bias) + key.AddSign(*bias); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd); + auto ins_ret = bwds.insert( + std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& inputs, const std::vector& req, @@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]); CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; - mkldnn::convolution_backward_data::primitive_desc bwdData_pd - = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1], - inputs[conv::kOut], fwd_pd); + MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1], + inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd); auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder( - bwdData_pd.diff_dst_primitive_desc()); + convBwd.bwdData_pd.diff_dst_primitive_desc()); if (req[conv::kData]) { auto weight_mem = GetWeights(inputs[conv::kWeight + 1], - bwdData_pd.weights_primitive_desc(), param.num_group); + convBwd.bwdData_pd.weights_primitive_desc(), param.num_group); auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], - bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd, - *out_grad_mem, *weight_mem, *in_grad_mem.second)); + convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); + convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); + MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData()); CommitOutput(in_grad[conv::kData], in_grad_mem); } if (req[conv::kWeight]) { - mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd - = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1], - param.no_bias ? 
nullptr : &inputs[conv::kBias + 1], - inputs[conv::kOut], fwd_pd); - if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc()) + MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1], + inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1], + inputs[conv::kOut], fwd_pd); + if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() != + convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()) out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder( - bwdWeights_pd.diff_dst_primitive_desc()); + convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()); auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder( - bwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight], - bwdWeights_pd.diff_weights_primitive_desc(), - req[conv::kWeight]); + convBwdWeight.bwdWeights_pd.src_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad( + in_grad[conv::kWeight], + convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(), + req[conv::kWeight]); mkldnn_output_t in_grad_bias; if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); + convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, + *in_grad_weight.second); + MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); } else { - in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias], - bwdWeights_pd.diff_bias_primitive_desc(), - req[conv::kBias]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, - *in_grad_bias.second)); + in_grad_bias = CreateMKLDNNMem( + in_grad[conv::kBias], + convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]); + convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, + *in_grad_weight.second, *in_grad_bias.second); + MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); CommitOutput(in_grad[conv::kBias], in_grad_bias); } CommitOutput(in_grad[conv::kWeight], in_grad_weight); @@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct } // namespace op } // namespace mxnet - #endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 7f3676a70dd0..54d4f6708524 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -93,9 +93,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( return mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd); } -static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData( - const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, - bool has_bias, const NDArray &output) { +static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, const NDArray &output) { auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); @@ -116,9 +116,10 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData( strides, padding, dilate); } -static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights( - const DeconvolutionParam& param, const NDArray &data, const NDArray &weights, - bool has_bias, const NDArray 
&output, +static mkldnn::convolution_backward_weights::primitive_desc +GetDeconvBwdWeightsImpl( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); @@ -308,55 +309,203 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data); } -void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { +class MKLDNNDeconvBackwardData { + std::shared_ptr bwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr out; + + public: + const mkldnn::convolution_forward::primitive_desc pd; + + MKLDNNDeconvBackwardData(const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output) + : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) { + } + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory &output) { + if (bwd == nullptr) { + this->data = std::shared_ptr( + new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle())); + this->weight = std::shared_ptr( + new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle())); + this->out = std::shared_ptr( + new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle())); + bwd = std::shared_ptr( + new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + *this->out)); + } else { + this->data->set_data_handle(data.get_data_handle()); + this->weight->set_data_handle(weight.get_data_handle()); + this->out->set_data_handle(output.get_data_handle()); + } + } + + const mkldnn::convolution_forward &GetBwd() const { return *bwd; } +}; + +typedef ParamOpSign MKLDNNDeconvSignature; + +static inline MKLDNNDeconvBackwardData &GetDeconvBwdData( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map + bwds; +#else + static MX_THREAD_LOCAL std::unordered_map + bwds; +#endif + MKLDNNDeconvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. 
+ key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNDeconvBackwardData bwd(param, data, weights, output); + auto ins_ret = bwds.insert( + std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +class MKLDNNDeconvBackwardWeights { + std::shared_ptr bwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr out; + + public: + const mkldnn::convolution_backward_weights::primitive_desc pd; + + MKLDNNDeconvBackwardWeights( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) + : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output, + bwd_data_pd)) {} + + void SetNewMem( + const mkldnn::memory &data, const mkldnn::memory &weight, + const mkldnn::memory &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) { + if (bwd == nullptr) { + this->data = std::shared_ptr(new mkldnn::memory( + bwd_data_pd.src_primitive_desc(), data.get_data_handle())); + this->weight = std::shared_ptr(new mkldnn::memory( + bwd_data_pd.weights_primitive_desc(), weight.get_data_handle())); + this->out = std::shared_ptr(new mkldnn::memory( + bwd_data_pd.dst_primitive_desc(), output.get_data_handle())); + bwd = std::shared_ptr( + new mkldnn::convolution_backward_weights(pd, *this->data, + *this->weight, *this->out)); + } else { + this->data->set_data_handle(data.get_data_handle()); + this->weight->set_data_handle(weight.get_data_handle()); + this->out->set_data_handle(output.get_data_handle()); + } + } + + const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; } +}; + +static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map + bwds; +#else + static MX_THREAD_LOCAL std::unordered_map + bwds; +#endif + MKLDNNDeconvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. 
+ key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + + auto it = bwds.find(key); + if (it == bwds.end()) { + MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd); + auto ins_ret = bwds.insert( + std::pair(key, + bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); const std::vector &in_grad = outputs; - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace"; - mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData( - param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], false, - inputs[deconv::kOut]); + const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); + CHECK_NE(req[deconv::kWeight], kWriteInplace) + << "cannot write weight inplace"; + MKLDNNDeconvBackwardData &bwd_data = + GetDeconvBwdData(param, inputs[deconv::kData + 1], + inputs[deconv::kWeight + 1], inputs[deconv::kOut]); auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwdData_pd.src_primitive_desc()); + bwd_data.pd.src_primitive_desc()); if (req[deconv::kData]) { - auto weight_mem = GetWeights(inputs[deconv::kWeight + 1], - bwdData_pd.weights_primitive_desc(), - param.num_group); - auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData], - bwdData_pd.dst_primitive_desc(), - req[deconv::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd, - *out_grad_mem, *weight_mem, *in_grad_mem.second)); + auto weight_mem = + GetWeights(inputs[deconv::kWeight + 1], + bwd_data.pd.weights_primitive_desc(), param.num_group); + auto in_grad_mem = + CreateMKLDNNMem(in_grad[deconv::kData], + bwd_data.pd.dst_primitive_desc(), req[deconv::kData]); + bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); + MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd()); CommitOutput(in_grad[deconv::kData], in_grad_mem); } if (req[deconv::kWeight]) { - mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd - = GetDeconvBwdWeights(param, inputs[deconv::kData + 1], - inputs[deconv::kWeight + 1], false, inputs[deconv::kOut], bwdData_pd); - if (bwdData_pd.src_primitive_desc() != bwdWeights_pd.src_primitive_desc()) + MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights( + param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], + inputs[deconv::kOut], bwd_data.pd); + if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc()) out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwdWeights_pd.src_primitive_desc()); + bwd_weights.pd.src_primitive_desc()); auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder( - bwdWeights_pd.diff_dst_primitive_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight], - bwdWeights_pd.diff_weights_primitive_desc(), - req[deconv::kWeight]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second)); + bwd_weights.pd.diff_dst_primitive_desc()); + auto in_grad_weight = CreateMKLDNNWeightGrad( + in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(), + req[deconv::kWeight]); + bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd); + 
MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd()); CommitOutput(in_grad[deconv::kWeight], in_grad_weight); } MKLDNNStream::Get()->Submit(); if (!param.no_bias) { typedef float DType; Stream *s = ctx.get_stream(); - Tensor gbias = in_grad[deconv::kBias].data().get(s); + Tensor gbias = + in_grad[deconv::kBias].data().get(s); // If there is bias, the out grad has already been converted to the default // format, so this shouldn't cause any performance issues. - Tensor grad = inputs[deconv::kOut].data().get(s); - Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad)); + Tensor grad = + inputs[deconv::kOut].data().get(s); + Assign(gbias, req[deconv::kBias], + mshadow::expr::sumall_except_dim<1>(grad)); } } diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index 4b179a7fbc98..bc386bedde1a 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -18,7 +18,7 @@ */ /*! - * \file mkldnn_lrn-inl.h + * \file mkldnn_lrn-inl.h * \brief * \Author: Patric Zhao, patric.zhao@intel.com */ @@ -40,9 +40,8 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam ¶m) { return algorithm::lrn_across_channels; } -inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam ¶m, - const bool is_train, - const memory::desc &src_md) { +inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc( + const LRNParam ¶m, const bool is_train, const memory::desc &src_md) { mkldnn::engine &engine = CpuEngine::Get()->get_engine(); const algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; @@ -59,11 +58,10 @@ inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam ¶m, return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine); } -inline mkldnn::lrn_backward::primitive_desc -GetLRNBwd(const LRNParam ¶m, - const mkldnn::memory::desc &data_in_md, - const mkldnn::memory::desc &diff_md, - const lrn_forward::primitive_desc &lrnFwd_desc) { +inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc( + const LRNParam ¶m, const mkldnn::memory::desc &data_in_md, + const mkldnn::memory::desc &diff_md, + const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) { mkldnn::engine &engine = CpuEngine::Get()->get_engine(); const algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; @@ -96,8 +94,15 @@ class MKLDNNLRNFwd { const NDArray &output, const OpReqType req); + void SetNewMem(const NDArray &in_data, + const mkldnn::memory *out_mem); + void Execute(const NDArray &out_data); + mkldnn::lrn_forward &GetFwd(); + + const mkldnn::memory *GetWs(); + private: std::shared_ptr fwd; std::shared_ptr in_mem; @@ -113,15 +118,17 @@ class MKLDNNLRNFwd { void MKLDNNLRNFwd::_Init(const LRNParam ¶m, bool is_train, const NDArray &in_data) { - mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc(); - lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md); + mkldnn::memory::desc in_data_md = + in_data.GetMKLDNNData()->get_primitive_desc().desc(); + mkldnn::lrn_forward::primitive_desc fwd_pd = + GetLRNFwdDesc(param, is_train, in_data_md); this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData() ->get_primitive_desc())); this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc())); if (is_train) { - // If it's training, we have to create a workspace memory. Otherwise, MKLDNN - // will have segmentation fault. + // If it's training, we have to create a workspace memory. Otherwise, + // MKLDNN will have segmentation fault. 
ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc())); this->fwd = std::shared_ptr( new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem), @@ -142,11 +149,22 @@ void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data, this->out_mem->set_data_handle(output_mem_t.second->get_data_handle()); } +void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data, + const mkldnn::memory *out_mem) { + const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData(); + this->in_mem->set_data_handle(in_data_mem->get_data_handle()); + this->out_mem->set_data_handle(out_mem->get_data_handle()); +} + void MKLDNNLRNFwd::Execute(const NDArray &out_data) { MKLDNNStream::Get()->RegisterPrim(*(this->fwd)); CommitOutput(out_data, output_mem_t); MKLDNNStream::Get()->Submit(); } + +mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; } + +const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); } // End of LRN Class and its functions static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, @@ -161,16 +179,10 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, MKLDNNLRNFwd, OpHash> lrn_fwds; #endif - auto alg_ = algorithm::lrn_across_channels; - auto kind_ = prop_kind::forward_training; - if (ctx.is_train) { - kind_ = prop_kind::forward_training; - } else { - kind_ = prop_kind::forward_scoring; - } + auto kind_ = + ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring; MKLDNNLRNSignature key(param); - key.AddSign(alg_); key.AddSign(kind_); key.AddSign(in_data); @@ -185,10 +197,8 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, return it->second; } -void MKLDNNLRNForward(const OpContext &ctx, - const LRNParam ¶m, - const NDArray &in_data, - const OpReqType req, +void MKLDNNLRNForward(const OpContext &ctx, const LRNParam ¶m, + const NDArray &in_data, const OpReqType req, const NDArray &out_data) { auto in_buffer = in_data; if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) @@ -198,6 +208,90 @@ void MKLDNNLRNForward(const OpContext &ctx, fwd.Execute(out_data); } +// LRN Backward Class +class MKLDNNLRNBwd { + std::shared_ptr bwd; + std::shared_ptr in_data_mem; + std::shared_ptr diff_dst_mem; + std::shared_ptr ws_mem; + std::shared_ptr diff_src_mem; + + public: + const mkldnn::lrn_forward::primitive_desc fwd_pd; + const mkldnn::lrn_backward::primitive_desc bwd_pd; + + ~MKLDNNLRNBwd() {} + + MKLDNNLRNBwd(const LRNParam ¶m, const mkldnn::memory::desc in_data_md, + const mkldnn::memory::desc diff_md) + : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)), + bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {} + + void SetNewMem(const NDArray &in_data, const NDArray &out_grad, + const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) { + if (bwd == nullptr) { + this->in_data_mem.reset( + new mkldnn::memory(this->fwd_pd.src_primitive_desc(), + in_data.GetMKLDNNData()->get_data_handle())); + this->diff_dst_mem.reset( + new mkldnn::memory(this->fwd_pd.dst_primitive_desc(), + out_grad.GetMKLDNNData()->get_data_handle())); + this->ws_mem.reset( + new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(), + ws->get_data_handle())); + this->diff_src_mem.reset( + new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(), + diff_src_mem->get_data_handle())); + this->bwd.reset(new mkldnn::lrn_backward( + this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem), + mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem, + *this->diff_src_mem)); + } else { + this->in_data_mem->set_data_handle( + in_data.GetMKLDNNData()->get_data_handle()); + 
this->diff_dst_mem->set_data_handle( + out_grad.GetMKLDNNData()->get_data_handle()); + this->ws_mem->set_data_handle(ws->get_data_handle()); + this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle()); + } + } + + void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) { + MKLDNNStream::Get()->RegisterPrim(*(this->bwd)); + CommitOutput(in_grad, diff_src_mem_); + MKLDNNStream::Get()->Submit(); + } +}; // End of LRN Class + +static MKLDNNLRNBwd &GetLRNBwd(const LRNParam ¶m, const NDArray &in_data, + const NDArray &in_grad, const NDArray &out_grad) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local + std::unordered_map lrn_bwds; +#else + static MX_THREAD_LOCAL + std::unordered_map lrn_bwds; +#endif + MKLDNNLRNSignature key(param); + key.AddSign(in_data); + key.AddSign(in_grad); + key.AddSign(out_grad); + + auto it = lrn_bwds.find(key); + if (it == lrn_bwds.end()) { + const mkldnn::memory::desc in_data_md = + in_data.GetMKLDNNData()->get_primitive_desc().desc(); + const mkldnn::memory::desc diff_md = + out_grad.GetMKLDNNData()->get_primitive_desc().desc(); + MKLDNNLRNBwd bwd(param, in_data_md, diff_md); + auto ins_ret = + lrn_bwds.insert(std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, const NDArray &out_grad, const NDArray &in_data, @@ -206,43 +300,27 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, if (req == kNullOp) { return; } - // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when in_data is nchw8c auto in_buffer = in_data; if (in_buffer.IsMKLDNNData()) { in_buffer = in_data.Reorder2Default(); } - + MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad); // Repeat FW for getting workspace - const mkldnn::memory *data_mem = in_buffer.GetMKLDNNData(); - const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc(); - const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train, - data_md); - // TODO(Patric): To keep the function stateless, we can't pass workspace // from LRN forward to backward. We have to re-compute // LRN forward to get the workspace. // Will refine this code later. 
- std::shared_ptr ws_mem( - new mkldnn::memory(pdesc_fwd.workspace_primitive_desc())); + MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer); std::shared_ptr dst_temp( - new mkldnn::memory(pdesc_fwd.dst_primitive_desc())); - MKLDNNStream::Get()->RegisterPrim( - lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem), - *ws_mem, *dst_temp)); - - const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData(); - const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc(); - const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_md, - diff_md, pdesc_fwd); - mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad, - pdesc_bwd.diff_src_primitive_desc(), req); - - MKLDNNStream::Get()->RegisterPrim( - lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem), - mkldnn::primitive::at(*diff_mem), *ws_mem, *diff_src_mem.second)); - CommitOutput(in_grad, diff_src_mem); - MKLDNNStream::Get()->Submit(); + new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc())); + fwd.SetNewMem(in_buffer, dst_temp.get()); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + + mkldnn_output_t diff_src_mem = + CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req); + bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second); + bwd.Execute(in_grad, diff_src_mem); } } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 5d349d372026..66679613d3ae 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -80,6 +80,27 @@ class MKLDNNPoolingFwd { const int padding_l, const int padding_r); }; +class MKLDNNPoolingBwd { + std::shared_ptr bwd; + std::shared_ptr diff_dst; + std::shared_ptr diff_src; + std::shared_ptr ws; + bool with_workspace; + + public: + const mkldnn::pooling_backward::primitive_desc pd; + + MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc, + bool with_ws); + + ~MKLDNNPoolingBwd() {} + void SetNewMem(const mxnet::NDArray *workspace, + const mxnet::NDArray &out_grad, + const mkldnn::memory *diff_src_mem); + const mkldnn::pooling_backward &GetBwd(); + const mkldnn::pooling_backward::primitive_desc &GetPd(); +}; + inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { return param.kernel.ndim() == 2 && (param.pool_type == pool_enum::kMaxPooling || diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index d8d65badc1c4..1610944304e1 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -134,10 +134,9 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { } } -mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam ¶m, - const bool is_train, - const memory::desc &data_md, - const memory::desc &out_md) { +mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( + const PoolingParam ¶m, const bool is_train, const memory::desc &data_md, + const memory::desc &out_md) { CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; int kernel_h_, kernel_w_; if (param.global_pool) { @@ -255,11 +254,124 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &in_data, const OpReqType req, const NDArray &out_data, const NDArray *workspace) { - auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); + auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); fwd.SetNewMem(in_data, out_data, req, 
workspace); fwd.Execute(out_data); } +MKLDNNPoolingBwd::MKLDNNPoolingBwd( + const pooling_backward::primitive_desc &pdesc, bool with_ws) + : with_workspace(with_ws), pd(pdesc) {} + +void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace, + const mxnet::NDArray &out_grad, + const mkldnn::memory *diff_src_mem) { + if (bwd == nullptr) { + diff_dst.reset( + new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(), + out_grad.GetMKLDNNData()->get_data_handle())); + diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(), + diff_src_mem->get_data_handle())); + if (with_workspace) { + CHECK(workspace != nullptr); + ws.reset( + new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(), + workspace->GetMKLDNNData()->get_data_handle())); + bwd.reset( + new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src)); + } else { + bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src)); + } + } else { + diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle()); + diff_src->set_data_handle(diff_src_mem->get_data_handle()); + if (with_workspace) { + CHECK(workspace != nullptr); + ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle()); + } + } +} + +const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() { + return *this->bwd; +} + +MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, + const NDArray &in_data, + const NDArray &in_grad, + const NDArray &out_grad) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local + std::unordered_map pooling_bwds; +#else + static MX_THREAD_LOCAL + std::unordered_map pooling_bwds; +#endif + + bool with_workspace = MKLDNNRequireWorkspace(param); + MKLDNNPoolingSignature key(param); + key.AddSign(in_data); + key.AddSign(in_grad); + key.AddSign(out_grad); + + auto it = pooling_bwds.find(key); + if (it == pooling_bwds.end()) { + auto diff_dst_mem = out_grad.GetMKLDNNData(); + auto input_mem = in_data.GetMKLDNNData(); + mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); + const mkldnn::memory::desc data_md = data_mpd.desc(); + const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], + static_cast(out_grad.shape()[2]), + static_cast(out_grad.shape()[3])}; + const memory::desc out_md( + {dims}, static_cast(data_md.data.data_type), + static_cast(data_md.data.format)); + auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md); + + const mkldnn::memory::desc diff_md = + diff_dst_mem->get_primitive_desc().desc(); + const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], + static_cast(in_grad.shape()[2]), + static_cast(in_grad.shape()[3])}; + const memory::desc diff_in_md( + {dims1}, static_cast(diff_md.data.data_type), + static_cast(diff_md.data.format)); + const mkldnn::engine cpu_engine = data_mpd.get_engine(); + const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); + + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + + int pad_t_ = param.pad[0], pad_b_ = param.pad[0]; + int pad_l_ = param.pad[1], pad_r_ = param.pad[1]; + int stride_h_ = param.stride[0], stride_w_ = param.stride[1]; + if (param.global_pool) { + pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0; + stride_h_ = stride_w_ = 1; + } + + const pooling_backward::desc desc( + alg, diff_in_md, diff_md, {stride_h_, stride_w_}, + {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}, + mkldnn::padding_kind::zero); + const auto pdesc = 
pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); + MKLDNNPoolingBwd bwd(pdesc, with_workspace); + auto ins_ret = pooling_bwds.insert( + std::pair(key, bwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &out_grad, const NDArray &in_data, const NDArray *workspace, const OpReqType req, @@ -267,68 +379,14 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, if (req == kNullOp) { return; } - TmpMemMgr::Get()->Init(ctx.requested[0]); - // mkldnn::memory - auto diff_dst_mem = out_grad.GetMKLDNNData(); - auto input_mem = in_data.GetMKLDNNData(); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - const mkldnn::memory::desc data_md = data_mpd.desc(); - const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], - static_cast(out_grad.shape()[2]), - static_cast(out_grad.shape()[3])}; - const memory::desc out_md({dims}, - static_cast(data_md.data.data_type), - static_cast(data_md.data.format)); - auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md); - - const mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc(); - const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], - static_cast(in_grad.shape()[2]), - static_cast(in_grad.shape()[3])}; - const memory::desc diff_in_md( - {dims1}, static_cast(diff_md.data.data_type), - static_cast(diff_md.data.format)); - const mkldnn::engine cpu_engine = data_mpd.get_engine(); - const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); - - int kernel_h_, kernel_w_; - if (param.global_pool) { - kernel_h_ = data_md.data.dims[2]; - kernel_w_ = data_md.data.dims[3]; - } else { - kernel_h_ = param.kernel[0]; - kernel_w_ = param.kernel[1]; - } - - int pad_t_ = param.pad[0], pad_b_ = param.pad[0]; - int pad_l_ = param.pad[1], pad_r_ = param.pad[1]; - int stride_h_ = param.stride[0], stride_w_ = param.stride[1]; - if (param.global_pool) { - pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0; - stride_h_ = stride_w_ = 1; - } - - const pooling_backward::desc desc(alg, diff_in_md, diff_md, - {stride_h_, stride_w_}, - {kernel_h_, kernel_w_}, - {pad_t_, pad_l_}, {pad_b_, pad_r_}, - mkldnn::padding_kind::zero); - const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd); + auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); auto diff_src_mem = - CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req); - - if (MKLDNNRequireWorkspace(param)) { - CHECK(workspace != nullptr); - auto workspace_mem = workspace->GetMKLDNNData(); - MKLDNNStream::Get()->RegisterPrim( - pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem), - *diff_src_mem.second)); - } else { - MKLDNNStream::Get()->RegisterPrim( - pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second)); - } + CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); + + bwd.SetNewMem(workspace, out_grad, diff_src_mem.second); + MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); CommitOutput(in_grad, diff_src_mem); MKLDNNStream::Get()->Submit(); }
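
All of the backward paths touched above (activation, batch norm, convolution, deconvolution, LRN, pooling) follow the same idiom: build the MKLDNN backward primitive_desc and primitive once, store the wrapper in a thread-local unordered_map keyed by an op signature (the parsed param plus AddSign of the arrays involved), and on later calls only rebind the data handles through SetNewMem/SetDataHandle before registering the cached primitive on the MKLDNNStream. A minimal, self-contained sketch of that lookup-or-create idiom is below; OpSignature, CachedBwd and GetCachedBwd are hypothetical stand-ins used only to illustrate the pattern, not the real MXNet (ParamOpSign, OpHash) or MKLDNN types.

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <unordered_map>
    #include <vector>

    // Stand-in for an op signature: collects a param hash plus the
    // shape/dtype identity of each array passed to AddSign.
    struct OpSignature {
      std::vector<uint64_t> eles;
      void AddSign(uint64_t v) { eles.push_back(v); }
      bool operator==(const OpSignature &o) const { return eles == o.eles; }
    };

    struct OpSignatureHash {
      size_t operator()(const OpSignature &s) const {
        size_t h = 0;
        for (uint64_t v : s.eles) h = h * 31 + std::hash<uint64_t>()(v);
        return h;
      }
    };

    // Stand-in for a cached backward primitive: the expensive setup happens
    // once in the constructor, while SetNewMem only rebinds the in/out
    // pointers, mirroring SetNewMem/SetDataHandle in the patch.
    class CachedBwd {
     public:
      explicit CachedBwd(uint64_t shape_hash) : shape_hash_(shape_hash) {
        std::cout << "building primitive for signature " << shape_hash_ << "\n";
      }
      void SetNewMem(const float *src, float *dst) { src_ = src; dst_ = dst; }
      void Execute(size_t n) const {
        for (size_t i = 0; i < n; ++i) dst_[i] = src_[i] * 2.0f;  // dummy kernel
      }
     private:
      uint64_t shape_hash_;
      const float *src_ = nullptr;
      float *dst_ = nullptr;
    };

    // Lookup-or-create, as GetActBackward/GetConvBwd/GetLRNBwd/GetPoolingBwd
    // do: a thread-local map so each worker thread owns its primitives.
    CachedBwd &GetCachedBwd(uint64_t param_hash, uint64_t shape_hash) {
      static thread_local
          std::unordered_map<OpSignature, CachedBwd, OpSignatureHash> bwds;
      OpSignature key;
      key.AddSign(param_hash);
      key.AddSign(shape_hash);
      auto it = bwds.find(key);
      if (it == bwds.end())
        it = bwds.emplace(key, CachedBwd(shape_hash)).first;
      return it->second;
    }

    int main() {
      std::vector<float> grad_out = {1, 2, 3, 4}, grad_in(4);
      for (int iter = 0; iter < 3; ++iter) {
        CachedBwd &bwd = GetCachedBwd(/*param_hash=*/7, grad_out.size());
        bwd.SetNewMem(grad_out.data(), grad_in.data());  // rebind handles only
        bwd.Execute(grad_in.size());                     // primitive built once
      }
      std::cout << grad_in[0] << " " << grad_in[3] << "\n";  // prints "2 8"
      return 0;
    }

The caches in the patch are declared with DMLC_CXX11_THREAD_LOCAL or MX_THREAD_LOCAL for the same reason the sketch uses thread_local: each worker thread keeps its own primitives, so cached objects are never shared across threads and no locking is needed on the hot path.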