diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 1ce36303689d..c862607372a9 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -93,6 +93,10 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); +void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); /* For softmax_output */ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc index 1235f3c121fc..e96ab6c20ca3 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc @@ -42,6 +42,18 @@ static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(bool is_train, return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine); } +static mkldnn::softmax_backward::primitive_desc GetSoftmaxBwdPd( + const mkldnn::memory &diff_mem, + const mkldnn::memory &data_mem, + const int axis, + const mkldnn::softmax_forward::primitive_desc &hint_fwd_pd) { + mkldnn::memory::desc diff_md = diff_mem.get_desc(); + mkldnn::memory::desc data_md = data_mem.get_desc(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto desc = mkldnn::softmax_backward::desc(diff_md, data_md, axis); + return mkldnn::softmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd); +} + bool SupportMKLDNNSoftmax(const SoftmaxParam ¶m, const NDArray &data, @@ -131,6 +143,78 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, stream->Submit(); } +class MKLDNNSoftmaxBwd { + public: + mkldnn::softmax_backward::primitive_desc pd; + + MKLDNNSoftmaxBwd(const mkldnn::memory &diff_mem, + const mkldnn::memory &data_mem, + const int axis, + const mkldnn::softmax_forward::primitive_desc &hint_fwd_pd) : + pd(GetSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) { + bwd_ = std::make_shared(pd); + } + + const mkldnn::softmax_backward &GetBwd() const { + return *bwd_; + } + + private: + std::shared_ptr bwd_; +}; + +static MKLDNNSoftmaxBwd &GetSoftmaxBwd(const SoftmaxParam ¶m, + const int real_axis, + const std::vector &data, + const std::vector &output) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map bwds; +#else + static MX_THREAD_LOCAL std::unordered_map bwds; +#endif + + MKLDNNSoftmaxSignature key(param); + key.AddSign(real_axis); + key.AddSign(data); + key.AddSign(output); + + auto it = bwds.find(key); + if (it == bwds.end()) { + auto diff_mem = data[0].GetMKLDNNData(); + auto data_mem = data[1].GetMKLDNNData(); + auto fwd_pd = GetSoftmaxFwdPd(true, real_axis, *data_mem); + MKLDNNSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd); + it = AddToCache(&bwds, key, bwd); + } + return it->second; +} + +void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + if (req[0] == kNullOp) return; + CHECK_EQ(in_data.size(), 2U); + const SoftmaxParam& param = nnvm::get(attrs.parsed); + int axis = CheckAxis(param.axis, in_data[1].shape().ndim()); + auto diff_mem = in_data[0].GetMKLDNNData(); + auto data_mem = in_data[1].GetMKLDNNData(); + auto bwd = GetSoftmaxBwd(param, axis, in_data, out_data); + + auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]); + MKLDNNStream *stream = MKLDNNStream::Get(); + mkldnn_args_map_t args = { + { MKLDNN_ARG_DST, *data_mem }, + { MKLDNN_ARG_DIFF_DST, *diff_mem }, + { MKLDNN_ARG_DIFF_SRC, *out_mem.second } + }; + + stream->RegisterPrimArgs(bwd.GetBwd(), args); + CommitOutput(out_data[0], out_mem); + stream->Submit(); +} + } // namespace op } // namespace mxnet #endif diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 57edab7037d5..97949ffbc81e 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -41,7 +41,6 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - // It seems MKLDNN softmax doesn't support training. const SoftmaxParam& param = nnvm::get(attrs.parsed); if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); @@ -54,6 +53,23 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs, inputs, req, outputs); } +static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const SoftmaxParam& param = nnvm::get(attrs.parsed); + if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) { + MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNRun(MKLDNNSoftmaxBackward, attrs, ctx, inputs, req, outputs); + auto fn = SoftmaxGradCompute; + MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(SoftmaxGradCompute, attrs, ctx, + inputs, req, outputs); +} + inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -72,6 +88,23 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs, return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); } + +inline static bool SoftmaxGradStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const SoftmaxParam& param = nnvm::get(attrs.parsed); + if (param.use_length.value() || softmax_has_dtype_override(attrs)) { + auto& out_stype = out_attrs->at(0); + return storage_type_assign(&out_stype, kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, + out_attrs); +} #endif @@ -147,8 +180,12 @@ NNVM_REGISTER_OP(_backward_softmax) .set_attr("FInplaceOption", SoftmaxGradOpInplaceOption) .add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments") .set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", SoftmaxGradComputeExCPU) +.set_attr("FInferStorageType", SoftmaxGradStorageType) +#endif .set_attr("FCompute", SoftmaxGradCompute); - } // namespace op } // namespace mxnet