diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 21ec52515983..d385b93e9cff 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -52,10 +52,12 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
   }
-
-  in_shape->at(layernorm::kGamma) = mxnet::TShape(Shape1(channelCount));
-  in_shape->at(layernorm::kBeta) = mxnet::TShape(Shape1(channelCount));
-
+  SHAPE_ASSIGN_CHECK(*in_shape,
+                     layernorm::kGamma,
+                     mxnet::TShape(Shape1(channelCount)));
+  SHAPE_ASSIGN_CHECK(*in_shape,
+                     layernorm::kBeta,
+                     mxnet::TShape(Shape1(channelCount)));
   out_shape->clear();
   out_shape->push_back(dshape);  // kOut
   mxnet::TShape moments_shape(dshape.begin(), dshape.end());
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 37b9cd7a0697..a02682557954 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -21,7 +21,7 @@
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
-from mxnet.base import py_str
+from mxnet.base import py_str, MXNetError
 from mxnet.test_utils import assert_almost_equal
 from mxnet.util import is_np_array
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
@@ -894,7 +894,13 @@ def test_instancenorm():
 def test_layernorm():
     layer = nn.LayerNorm(in_channels=10)
     check_layer_forward(layer, (2, 10, 10, 10))
-
+    # Check for the case of error raising
+    for hybridize in [False, True]:
+        layer = nn.LayerNorm(in_channels=10)
+        layer.initialize()
+        if hybridize:
+            layer.hybridize()
+        assert_raises(MXNetError, lambda: layer(mx.nd.ones((2, 11))))
 
 @with_seed()
 def test_groupnorm():
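
Usage sketch (not part of the patch; a hypothetical illustration assuming an MXNet build that includes this change): replacing the direct gamma/beta assignments with SHAPE_ASSIGN_CHECK means a data shape that conflicts with the already-known gamma/beta shapes is rejected during shape inference rather than silently overwritten, which is what the added test exercises from Python:

# Hypothetical repro, mirrors the new test case in test_layernorm().
import mxnet as mx
from mxnet.base import MXNetError
from mxnet.gluon import nn

layer = nn.LayerNorm(in_channels=10)
layer.initialize()

layer(mx.nd.ones((2, 10)))      # last axis matches in_channels=10, works

try:
    layer(mx.nd.ones((2, 11)))  # last axis 11 conflicts with gamma/beta shape (10,)
except MXNetError:
    print("shape mismatch rejected during shape inference, as expected")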