Skip to content

Commit

Permalink
Fix MKLDNN BatchNorm with even number of channels (apache#19150)
Browse files Browse the repository at this point in the history
Even number of channels results in data reordering before batch
norm operation. Therefore, if BatchNorm data array is view of
another array and the data is stored in MKLDNN format, the data
needs to be converted to the default format.
  • Loading branch information
akarbown committed Oct 26, 2020
1 parent 8dc3652 commit 4567680
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
return it->second;
}

template<typename DType>
static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
const OpContext &ctx, const NDArray &in_data,
mkldnn::normalization_flags flags) {
return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
}

template <typename DType>
void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
Expand Down Expand Up @@ -182,8 +175,11 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
aux_states,
ctx.is_train && !param.use_global_stats,
fuse_relu);
const NDArray &data = in_data[batchnorm::kData];
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
NDArray &data = in_data[batchnorm::kData];
if (data.IsMKLDNNData() && data.IsView())
data = data.Reorder2Default();
auto data_mem = data.GetMKLDNNData();
auto &fwd = GetBNForward<DType>(param, ctx, data_mem, flags);

// for output memory
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
Expand Down Expand Up @@ -221,7 +217,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}

mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
net_args[MKLDNN_ARG_DST] = *out_mem;
if (fuse_relu) {
Expand Down

0 comments on commit 4567680

Please sign in to comment.