From cd53fb4ce37a182afdb86acc512633244eba3972 Mon Sep 17 00:00:00 2001 From: Da zheng Date: Tue, 7 Nov 2017 04:14:24 +0000 Subject: [PATCH] Fix a bug in deconvolution. --- src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 8a8566432706..7e5daf6ed251 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -176,11 +176,11 @@ void MKLDNNDeconvolution_Forward(const nnvm::NodeAttrs& attrs, const OpContext & param, in_data[deconv::kData], in_data[deconv::kWeight], param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]); auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder( - deconvFwd_pd.diff_src_primitive_desc()); + deconvFwd_pd.diff_dst_primitive_desc()); auto weight_mem = GetWeights(in_data[deconv::kWeight], deconvFwd_pd.weights_primitive_desc(), param.num_group); auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut], - deconvFwd_pd.diff_dst_primitive_desc(), req[deconv::kOut]); + deconvFwd_pd.diff_src_primitive_desc(), req[deconv::kOut]); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_data( deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second)); @@ -225,9 +225,9 @@ void MKLDNNDeconvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext param.no_bias ? nullptr : &inputs[deconv::kWeight + 1], inputs[deconv::kOut], bwdData_pd); auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder( bwdWeights_pd.src_primitive_desc()); + auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder( + bwdWeights_pd.diff_dst_primitive_desc()); auto in_grad_weight = CreateMKLDNNMem(in_grad[deconv::kWeight], bwdWeights_pd.diff_weights_primitive_desc(), req[deconv::kWeight]); mkldnn_output_t in_grad_bias;