diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 5d722581257f..a7e63f6d4139 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -147,10 +147,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<NDArray> &outputs) {
-  // TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp
-  // Will be fixed when we decide to enable the backward of FC.
-  bool mkldnn_fc_backward_enable = false;
-  if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
+  if (SupportMKLDNNFC(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute, attrs, ctx, inputs, req,
@@ -232,12 +229,10 @@ static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   uint32_t out_expected = param.no_bias ? 2 : 3;
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), out_expected);
-  // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
-  // It seems there is a bug.
   bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+                                     dispatch_mode, DispatchMode::kFComputeEx);
   }
   if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
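The change to fully_connected.cc drops the stale TODO guard and switches dense-storage gradients from DispatchMode::kFCompute to kFComputeEx, so the backward pass now reaches FullyConnectedGradComputeExCPU (and from there MKLDNNFCBackward when SupportMKLDNNFC holds), while sparse inputs still fall back. Below is a minimal standalone sketch of that dispatch decision; the enums are simplified and DispatchFCBackward is a hypothetical helper, not MXNet's actual BackwardFCStorageType signature.

```cpp
// Simplified standalone model of the dispatch decision in the hunk above.
// The real logic lives in BackwardFCStorageType and uses
// storage_type_assign/dispatch_fallback; names here are illustrative.
#include <algorithm>
#include <iostream>
#include <vector>

enum class Storage { kDefault, kRowSparse };
enum class DispatchMode { kFCompute, kFComputeEx, kFComputeFallback };

// All-dense inputs now take the kFComputeEx path, which routes the backward
// pass to the MKL-DNN implementation; anything else falls back.
DispatchMode DispatchFCBackward(const std::vector<Storage>& in_attrs) {
  bool all_default = std::all_of(in_attrs.begin(), in_attrs.end(),
                                 [](Storage s) { return s == Storage::kDefault; });
  if (all_default) return DispatchMode::kFComputeEx;  // was kFCompute before this patch
  return DispatchMode::kFComputeFallback;             // e.g. row-sparse weight path
}

int main() {
  std::vector<Storage> dense  = {Storage::kDefault, Storage::kDefault, Storage::kDefault};
  std::vector<Storage> sparse = {Storage::kDefault, Storage::kRowSparse, Storage::kDefault};
  std::cout << (DispatchFCBackward(dense)  == DispatchMode::kFComputeEx) << '\n';        // 1
  std::cout << (DispatchFCBackward(sparse) == DispatchMode::kFComputeFallback) << '\n';  // 1
}
```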
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 1403cd114201..7627d02c4702 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -290,24 +290,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
       GetMemDesc(out_grad));
   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
-  if (req[fullc::kData]) {
-    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
-        data, weight, out_grad, fwd_pd);
-    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-                                       ipBwdData_pd.diff_src_desc(),
-                                       req[fullc::kData]);
-    mkldnn_args_map_t args = {
-      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
-      {MKLDNN_ARG_WEIGHTS, *weight_mem},
-      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
-    };
-
-    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
-    CommitOutput(in_grad[fullc::kData], in_grad_mem);
-  }
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd = GetFCBwdWeights(
         data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
@@ -336,6 +318,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
+  if (req[fullc::kData]) {
+    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
+        data, weight, out_grad, fwd_pd);
+    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
+        ipBwdData_pd.diff_dst_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
+                                       ipBwdData_pd.diff_src_desc(),
+                                       req[fullc::kData]);
+    mkldnn_args_map_t args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+    };
+
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
+    CommitOutput(in_grad[fullc::kData], in_grad_mem);
+  }
   MKLDNNStream::Get()->Submit();
 }
 
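Note that the two hunks in mkldnn_fully_connected.cc move the backward-data block rather than delete it: the second hunk re-adds it verbatim after the backward-weights/bias block, so the data-gradient primitive is the last one registered before the single Submit(). The toy model below illustrates that queue-then-submit ordering and the req gating; Stream and FCBackward are hypothetical stand-ins written in standard C++, not the real MKLDNNStream/MKLDNNFCBackward.

```cpp
// Toy model of the primitive ordering after this patch: each gradient is
// queued only when its req is not kNullOp, weights/bias before data, and one
// Submit() executes everything queued. Names are illustrative.
#include <functional>
#include <iostream>
#include <vector>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };  // real MXNet enum values

struct Stream {  // stand-in for MKLDNNStream
  std::vector<std::function<void()>> prims;
  void RegisterPrimArgs(std::function<void()> p) { prims.push_back(std::move(p)); }
  void Submit() { for (auto& p : prims) p(); prims.clear(); }
};

void FCBackward(Stream* s, OpReqType req_data, OpReqType req_weight) {
  // Weight (and bias) gradient is registered first, matching the new order.
  if (req_weight != kNullOp)
    s->RegisterPrimArgs([] { std::cout << "inner_product_backward_weights\n"; });
  // Data gradient is registered second, and skipped entirely when not requested.
  if (req_data != kNullOp)
    s->RegisterPrimArgs([] { std::cout << "inner_product_backward_data\n"; });
  s->Submit();  // one submission runs all queued primitives, as in the real code
}

int main() { Stream s; FCBackward(&s, kWriteTo, kWriteTo); }
```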