Skip to content

Commit

Permalink
Fix a bug in deconvolution.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Nov 7, 2017
1 parent beb8505 commit cd53fb4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ void MKLDNNDeconvolution_Forward(const nnvm::NodeAttrs& attrs, const OpContext &
param, in_data[deconv::kData], in_data[deconv::kWeight],
param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]);
auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
deconvFwd_pd.diff_src_primitive_desc());
deconvFwd_pd.diff_dst_primitive_desc());
auto weight_mem = GetWeights(in_data[deconv::kWeight],
deconvFwd_pd.weights_primitive_desc(), param.num_group);
auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
deconvFwd_pd.diff_dst_primitive_desc(), req[deconv::kOut]);
deconvFwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);

MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_data(
deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second));
Expand Down Expand Up @@ -225,9 +225,9 @@ void MKLDNNDeconvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext
param.no_bias ? nullptr : &inputs[deconv::kWeight + 1],
inputs[deconv::kOut], bwdData_pd);
auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
bwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
bwdWeights_pd.src_primitive_desc());
auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
bwdWeights_pd.diff_dst_primitive_desc());
auto in_grad_weight = CreateMKLDNNMem(in_grad[deconv::kWeight],
bwdWeights_pd.diff_weights_primitive_desc(), req[deconv::kWeight]);
mkldnn_output_t in_grad_bias;
Expand Down

0 comments on commit cd53fb4

Please sign in to comment.