Skip to content

Commit

Permalink
Improve activation backward (apache#17973)
Browse files Browse the repository at this point in the history
* fix activation backward

* comments

* fix
  • Loading branch information
TaoLv committed May 29, 2020
1 parent 0c6785f commit a077dca
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/operator/nn/mkldnn/mkldnn_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,21 +249,28 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx
auto input_mem = in_buffer.GetMKLDNNData();
// We need to make sure the two inputs to eltwise_backward has the same memory
// descriptor. Otherwise, the perf will suffer.
if (input_mem->get_desc() != diff_dst_memory->get_desc())
if (input_mem->get_desc() != diff_dst_memory->get_desc()) {
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
MKLDNNActBackward &bwd =
GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
}

MKLDNNActBackward &bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *input_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
{ MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second },
};
stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *input_mem},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory}};
if (req[0] != kAddTo) {
// req[0] is kWriteTo or kWriteInplace
auto diff_src_memory =
const_cast<NDArray &>(in_grad).CreateMKLDNNData(bwd.bwd_pd.diff_src_desc());
args.insert({MKLDNN_ARG_DIFF_SRC, *diff_src_memory});
stream->RegisterPrimArgs(bwd.GetBwd(), args);
stream->Submit();
} else {
auto diff_src_memory = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
args.insert({MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second});
stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
}
}

void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
Expand Down

0 comments on commit a077dca

Please sign in to comment.