From 83dfd4cbf659dd7516d5a1335efa3c948abc4409 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Mon, 24 Feb 2020 22:56:06 -0800 Subject: [PATCH 1/2] Update layer_norm.cc add test case for error checking --- src/operator/nn/layer_norm.cc | 10 ++++++---- tests/python/unittest/test_gluon.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 21ec52515983..91118e4b2bbc 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -52,10 +52,12 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, if (!mxnet::ndim_is_known(dshape)) { return false; } - - in_shape->at(layernorm::kGamma) = mxnet::TShape(Shape1(channelCount)); - in_shape->at(layernorm::kBeta) = mxnet::TShape(Shape1(channelCount)); - + SHAPE_ASSIGN_CHECK(*in_shape, + layernorm::kGamma, + mxnet::TShape(Shape1(channelCount))); + SHAPE_ASSIGN_CHECK(*in_shape, + layernorm::kBeta, + mxnet::TShape(Shape1(channelCount))); out_shape->clear(); out_shape->push_back(dshape); // kOut mxnet::TShape moments_shape(dshape.begin(), dshape.end()); diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 37b9cd7a0697..a02682557954 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -21,7 +21,7 @@ import mxnet as mx from mxnet import gluon from mxnet.gluon import nn -from mxnet.base import py_str +from mxnet.base import py_str, MXNetError from mxnet.test_utils import assert_almost_equal from mxnet.util import is_np_array from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID @@ -894,7 +894,13 @@ def test_instancenorm(): def test_layernorm(): layer = nn.LayerNorm(in_channels=10) check_layer_forward(layer, (2, 10, 10, 10)) - + # Check for the case of error raising + for hybridize in [False, True]: + layer = nn.LayerNorm(in_channels=10) + layer.initialize() + if hybridize: + layer.hybridize() + assert_raises(MXNetError, lambda: layer(mx.nd.ones((2, 11)))) @with_seed() def test_groupnorm(): From 77b8a4a13292098e843d16b0b35974132e33f2ac Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Mon, 24 Feb 2020 23:44:50 -0800 Subject: [PATCH 2/2] fix indent --- src/operator/nn/layer_norm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 91118e4b2bbc..d385b93e9cff 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -53,8 +53,8 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, return false; } SHAPE_ASSIGN_CHECK(*in_shape, - layernorm::kGamma, - mxnet::TShape(Shape1(channelCount))); + layernorm::kGamma, + mxnet::TShape(Shape1(channelCount))); SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kBeta, mxnet::TShape(Shape1(channelCount)));