From 9ab4a9b056df2e9ffbca71f8f9908bd5408151b7 Mon Sep 17 00:00:00 2001 From: Anna Karbownik Date: Mon, 5 Oct 2020 10:13:55 +0200 Subject: [PATCH] Fix MKLDNN BatchNorm with even number of channels (#19150) Even number of channels results in data reordering before batch norm operation. Therefore, if BatchNorm data array is view of another array and the data is stored in MKLDNN format, the data needs to be converted to the default format. --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 0a29a6d87de6..2274591be8ad 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -145,13 +145,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, return it->second; } -template -static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, - const OpContext &ctx, const NDArray &in_data, - mkldnn::normalization_flags flags) { - return GetBNForward(param, ctx, in_data.GetMKLDNNData(), flags); -} - template void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -182,8 +175,12 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu); - const NDArray &data = in_data[batchnorm::kData]; - auto &fwd = GetBNForward(param, ctx, data, flags); + + NDArray &data = in_data[batchnorm::kData]; + if (data.IsMKLDNNData() && data.IsView()) + data = data.Reorder2Default(); + auto data_mem = data.GetMKLDNNData(); + auto &fwd = GetBNForward(param, ctx, data_mem, flags); // for output memory auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); @@ -221,7 +218,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } mkldnn_args_map_t net_args; - net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); + net_args[MKLDNN_ARG_SRC] = *data_mem; net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem; net_args[MKLDNN_ARG_DST] = *out_mem; if (fuse_relu) {