diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index 8e2b57781a18..7b266efc2a14 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -92,13 +92,13 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, auto gz_mem = inputs[0].GetMKLDNNData(); mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc(); /* init the offset */ - mkldnn::memory::dims offsets = {0, 0, 0, 0}; + mkldnn::memory::dims offsets(outputs[0].shape().ndim()); + for (auto &v : offsets) { + v = 0; + } + for (int i = 0; i < num_in_data; i++) { - mkldnn::memory::dims diff_src_tz - = {static_cast(outputs[i].shape()[0]), - static_cast(outputs[i].shape()[1]), - static_cast(outputs[i].shape()[2]), - static_cast(outputs[i].shape()[3])}; + mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end()); auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc(); auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]); // create view from gy to gxs[i] diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index 3f3d82020598..96a8afdab6e2 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -42,7 +42,7 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, mkldnn::memory::dims offsets(N); for (uint32_t i = 0; i < N; ++i) { int s = 0; - if (param.begin[i]) { + if (i < param.begin.ndim() && param.begin[i]) { s = *param.begin[i]; if (s < 0) s += ishape[i]; }