Skip to content

Commit

Permalink
Mkldnn fullyConnect bwd bug fix (apache#16890)
Browse files Browse the repository at this point in the history
* fix mkldnn fc bwd bug due to data inplace

* enable mkldnn fc bwd
  • Loading branch information
rongzha1 authored and pengzhao-intel committed Nov 25, 2019
1 parent 9b49cfe commit 436967b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 25 deletions.
9 changes: 2 additions & 7 deletions src/operator/nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
// TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp
// Will be fixed when we decide to enable the backward of FC.
bool mkldnn_fc_backward_enable = false;
if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
if (SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
Expand Down Expand Up @@ -232,12 +229,10 @@ static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), out_expected);
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
// It seems there is a bug.
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
Expand Down
36 changes: 18 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,24 +290,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));

CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
if (req[fullc::kData]) {
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
};

MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
CommitOutput(in_grad[fullc::kData], in_grad_mem);
}
if (req[fullc::kWeight]) {
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
= GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
Expand Down Expand Up @@ -336,6 +318,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
}
if (req[fullc::kData]) {
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
};

MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
CommitOutput(in_grad[fullc::kData], in_grad_mem);
}
MKLDNNStream::Get()->Submit();
}

Expand Down

0 comments on commit 436967b

Please sign in to comment.