diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index de5c0149e612..4654c24547cd 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -249,21 +249,28 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx auto input_mem = in_buffer.GetMKLDNNData(); // We need to make sure the two inputs to eltwise_backward has the same memory // descriptor. Otherwise, the perf will suffer. - if (input_mem->get_desc() != diff_dst_memory->get_desc()) + if (input_mem->get_desc() != diff_dst_memory->get_desc()) { input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc()); - MKLDNNActBackward &bwd = - GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); + } + + MKLDNNActBackward &bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); - mkldnn_output_t diff_src_memory = - CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]); - mkldnn_args_map_t args = { - { MKLDNN_ARG_SRC, *input_mem }, - { MKLDNN_ARG_DIFF_DST, *diff_dst_memory }, - { MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second }, - }; - stream->RegisterPrimArgs(bwd.GetBwd(), args); - CommitOutput(in_grad, diff_src_memory); - stream->Submit(); + mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *input_mem}, + {MKLDNN_ARG_DIFF_DST, *diff_dst_memory}}; + if (req[0] != kAddTo) { + // req[0] is kWriteTo or kWriteInplace + auto diff_src_memory = + const_cast<NDArray &>(in_grad).CreateMKLDNNData(bwd.bwd_pd.diff_src_desc()); + args.insert({MKLDNN_ARG_DIFF_SRC, *diff_src_memory}); + stream->RegisterPrimArgs(bwd.GetBwd(), args); + stream->Submit(); + } else { + auto diff_src_memory = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]); + args.insert({MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second}); + stream->RegisterPrimArgs(bwd.GetBwd(), args); + CommitOutput(in_grad, diff_src_memory); + stream->Submit(); + } } void MKLDNNLeakyReluBackward(const
nnvm::NodeAttrs& attrs,