Skip to content

Commit

Permalink
change to need_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Nov 28, 2018
1 parent 95e360f commit 30e0636
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0]) && ctx.need_grad) {
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
out_grad = out_grad.Reorder2Default();

mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(
full_param, ctx.is_train, data, weight, bias, out_grad);
full_param, ctx.need_grad, data, weight, bias, out_grad);
const ConvolutionParam &param = full_param.conv_param;

CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
Expand Down

0 comments on commit 30e0636

Please sign in to comment.