diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 0a29a6d87de6..75c7c4dbf38a 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -145,13 +145,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, return it->second; } -template -static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, - const OpContext &ctx, const NDArray &in_data, - mkldnn::normalization_flags flags) { - return GetBNForward(param, ctx, in_data.GetMKLDNNData(), flags); -} - template void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -182,8 +175,11 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu); - const NDArray &data = in_data[batchnorm::kData]; - auto &fwd = GetBNForward(param, ctx, data, flags); + NDArray &data = in_data[batchnorm::kData]; + if (data.IsMKLDNNData() && data.IsView()) + data = data.Reorder2Default(); + auto data_mem = data.GetMKLDNNData(); + auto &fwd = GetBNForward(param, ctx, data_mem, flags); // for output memory auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); @@ -221,7 +217,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } mkldnn_args_map_t net_args; - net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); + net_args[MKLDNN_ARG_SRC] = *data_mem; net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem; net_args[MKLDNN_ARG_DST] = *out_mem; if (fuse_relu) { diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 3bfc99ee4a88..2fafc7821b5e 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -295,7 +295,7 @@ def test_mkldnn_sum_inplace_with_cpu_layout(): @with_seed() def test_batchnorm(): def check_batchnorm_training(stype): - for shape in [(2, 3), (2, 3, 2, 2)]: + for shape in [(2, 3), (2, 4), (2, 3, 2, 2), (2, 4, 2, 2)]: data_tmp = np.random.normal(-0.1, 0.1, size=shape) s = shape[1], gamma = np.ones(s) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 7bacb4f0b317..49b84a2b9d68 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -20,6 +20,7 @@ import mxnet as mx from mxnet import gluon +from mxnet import init from mxnet.gluon import nn from mxnet.base import py_str, MXNetError from mxnet.test_utils import assert_almost_equal, default_context @@ -2179,6 +2180,40 @@ def hybrid_forward(self, F, x): check_layer_forward_withinput(net, x) +@with_seed() +def test_batchnorm_chnls(): + chn_list = [1024, 512, 256, 128, 64, 45, 32, 16, 3] + class Net(gluon.HybridBlock): + def __init__(self, + chn_num, + norm_kwargs=None, + in_channels=3, + **kwargs): + super(Net, self).__init__(**kwargs) + self.in_channels = in_channels + self.conv1 = gluon.nn.Conv3D( + in_channels=self.in_channels, + channels=chn_num, + kernel_size=(1, 7, 7), + strides=(1, 2, 2), + padding=(0, 3, 3), + use_bias=False, + ) + self.bn1 = gluon.nn.BatchNorm(in_channels=chn_num, **({} if norm_kwargs is None else norm_kwargs)) + + def hybrid_forward(self, F, x): + """Hybrid forward of R2+1D net""" + conv = self.conv1(x) + out = self.bn1(conv) + return out + + for i in range(len(chn_list)): + net = Net(chn_list[i]) + net.initialize(init=init.Constant(1)) + x = mx.nd.zeros((1, 3, 8, 160, 160)) + net(x).asnumpy() + + @with_seed() def test_concat(): chn_list = [16, 64]