From f0059f611aa139bd91606385c20832694973bbfd Mon Sep 17 00:00:00 2001 From: Ciyong Chen Date: Wed, 15 May 2019 21:32:36 +0800 Subject: [PATCH] fix cpp test failure --- src/operator/nn/mkldnn/mkldnn_sum.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc index b5bdd2107ebf..724b8a2613d6 100644 --- a/src/operator/nn/mkldnn/mkldnn_sum.cc +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -141,6 +141,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, in_bufs[i] = inputs[i].Reorder2Default(); in_mem = in_bufs[i].GetMKLDNNData(); } else { + in_bufs[i] = inputs[i]; in_mem = inputs[i].GetMKLDNNData(); } mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc(); @@ -148,11 +149,11 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, data_mem.push_back(in_mem); } - MKLDNNSumFwd &fwd = GetSumForward(scales, inputs, data_md); + MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md); mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, - &inputs[0]); + &in_bufs[0]); fwd.SetNewMem(data_mem, *out_mem.second); MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); CommitOutput(out_data, out_mem);