diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 18055caaaeba..0e7a05669cdc 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -139,13 +139,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, return it->second; } -template -static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, - const OpContext &ctx, const NDArray &in_data, - mkldnn::normalization_flags flags) { - return GetBNForward(param, ctx, in_data.GetMKLDNNData(), flags); -} - template void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -176,8 +169,12 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, aux_states, param, ctx.is_train && !param.use_global_stats); - const NDArray &data = in_data[batchnorm::kData]; - auto &fwd = GetBNForward(param, ctx, data, flags); + + NDArray &data = in_data[batchnorm::kData]; + if (data.IsMKLDNNData() && data.IsView()) + data = data.Reorder2Default(); + auto data_mem = data.GetMKLDNNData(); + auto &fwd = GetBNForward(param, ctx, data_mem, flags); // for output memory auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); @@ -215,7 +212,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } mkldnn_args_map_t net_args; - net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); + net_args[MKLDNN_ARG_SRC] = *data_mem; net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem; net_args[MKLDNN_ARG_DST] = *out_mem;