
[LayerNorm] Missing the mismatch cues of in_channels #17654

Closed
zheyuye opened this issue Feb 21, 2020 · 1 comment · Fixed by #17683
Comments

@zheyuye
Contributor

zheyuye commented Feb 21, 2020

Description

It seems that nn.LayerNorm can run even when in_channels is set incorrectly. As shown in the reproducible code snippet below, I deliberately set in_channels to 768 in all cases, which does not match the input whose last-axis dimension is 1024. However, only the last of the three cases produces a "reasonable" error message.

I'm not entirely clear about the underlying implementation of nn.LayerNorm, but it makes no sense to me that the first two cases execute without complaint. I am wondering whether LayerNorm could be rechecked so that it generates an error message informing the user of the mismatch. At the moment an error is raised only when another layer is attached and the model is hybridized.

The above thinking and experiments were inspired by a typo in the SQuAD fine-tuning script of XLNet (linked below), which may need to be corrected. Surprisingly, the script runs even though the unit size of XLNet-large is 1024.

https://github.com/dmlc/gluon-nlp/blob/137e6b16bc1e672c6963a1e2ed754357e5a2ba11/scripts/language_model/model/qa.py#L37-L46

To Reproduce

import mxnet as mx
from mxnet.gluon import HybridBlock, nn
mx.npx.set_np()

# LayerNorm only: in_channels=768 deliberately mismatches the 1024-dim input.
class Foobar(HybridBlock):
    def __init__(self, units, prefix=None, params=None):
        super(Foobar, self).__init__(prefix=prefix, params=params)
        self.dense = nn.Dense(1, flatten=False)  # defined but never called
        self.layernorm = nn.LayerNorm(epsilon=1e-12, in_channels=768)
    def hybrid_forward(self, F, x):
        out = self.layernorm(x)
        return out

# LayerNorm followed by a Dense layer, with the same mismatched in_channels.
class Foo(HybridBlock):
    def __init__(self, units, prefix=None, params=None):
        super(Foo, self).__init__(prefix=prefix, params=params)
        self.dense = nn.Dense(1, flatten=False)
        self.layernorm = nn.LayerNorm(epsilon=1e-12, in_channels=768)
    def hybrid_forward(self, F, x):
        out = self.layernorm(x)
        out = self.dense(out)
        return out

# Case 1: hybridized, LayerNorm only -- runs without any error.
foo_0 = Foobar(units=1024)
foo_0.initialize(ctx=mx.gpu())
foo_0.hybridize()
out = foo_0(mx.np.random.normal(0, 1, size=(10, 1024), ctx=mx.gpu()))

# Case 2: not hybridized, LayerNorm + Dense -- also runs without any error.
foo_1 = Foo(units=1024)
foo_1.initialize(ctx=mx.gpu())
out = foo_1(mx.np.random.normal(0, 1, size=(10, 1024), ctx=mx.gpu()))

# Case 3: hybridized, LayerNorm + Dense -- the only case that raises an error.
foo_2 = Foo(units=1024)
foo_2.initialize(ctx=mx.gpu())
foo_2.hybridize()
out = foo_2(mx.np.random.normal(0, 1, size=(10, 1024), ctx=mx.gpu()))

Error Message

DeferredInitializationError: Parameter 'dense2_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.

During handling of the above exception, another exception occurred:
AssertionError: Expected shape (1024,) is incompatible with given shape (768,).

Comments

@sxjscience

zheyuye added the Bug label Feb 21, 2020
@sxjscience
Member

@zheyuye The C++-side implementation of the shape-inference logic is here: https://github.com/apache/incubator-mxnet/blob/9dcf71d8fe33f77ed316a95fcffaf1f7f883ff70/src/operator/nn/layer_norm.cc#L39-L66

The Python side is here: https://github.com/apache/incubator-mxnet/blob/9dcf71d8fe33f77ed316a95fcffaf1f7f883ff70/python/mxnet/gluon/nn/basic_layers.py#L609-L614

The problem is that we need to check the shapes of gamma and beta here:
https://github.com/apache/incubator-mxnet/blob/9dcf71d8fe33f77ed316a95fcffaf1f7f883ff70/src/operator/nn/layer_norm.cc#L56-L57

Would you like to investigate the issue? You can append std::cout << in_shape->at(layernorm::kGamma), which should not be empty when in_channels is given.
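
As a rough sketch of that debugging print (placement inside LayerNormShape in layer_norm.cc is assumed; mxnet::TShape can be streamed to std::ostream):

// Debugging aid only, not part of the proposed fix: print the incoming gamma
// shape. When in_channels is given on the Python side, this should already be
// (in_channels,) rather than an empty shape. Requires <iostream>.
std::cout << "gamma shape: " << in_shape->at(layernorm::kGamma) << std::endl;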

I think one way to solve the problem is to use the same SHAPE_ASSIGN_CHECK as is used here:
https://github.com/apache/incubator-mxnet/blob/9dcf71d8fe33f77ed316a95fcffaf1f7f883ff70/src/operator/numpy/np_where_op.cc#L42
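
For illustration only, a minimal sketch of how that could look for gamma and beta inside LayerNormShape (the name channel_size is an assumption standing for the size of the normalized axis taken from the data shape; the actual variable names may differ):

// Sketch, not the actual patch: assign-with-check instead of a plain assignment,
// so a user-supplied in_channels that conflicts with the data shape
// (e.g. 768 vs. 1024) fails shape inference with a clear message.
const index_t channel_size = dshape[axis];  // assumed: size of the normalized axis
SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kGamma,
                   mxnet::TShape(mshadow::Shape1(channel_size)));
SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kBeta,
                   mxnet::TShape(mshadow::Shape1(channel_size)));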
