
Conversation

@ADAning (Contributor) commented Jul 3, 2022

What does this PR do?

The official summarization examples use T5 as the pretrained model, but the name of the T5 layer norm is layer_norm, not LayerNorm:

from transformers import T5Tokenizer, T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained("t5-small")
for n, p in model.named_parameters():
    print(n)

Output:

...
encoder.block.0.layer.0.layer_norm.weight
encoder.block.0.layer.1.DenseReluDense.wi.weight
encoder.block.0.layer.1.DenseReluDense.wo.weight
encoder.block.0.layer.1.layer_norm.weight
...

In the official summarization example, layer_norm is not included:

no_decay = ["bias", "LayerNorm.weight"]
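
For reference, a minimal sketch of the kind of fix discussed in the review below: keep the existing "LayerNorm.weight" pattern so other models still match, and also exclude T5's layer_norm weights. Here model and args.weight_decay follow the example script's naming, and the exact list merged in the PR may differ:

no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
optimizer_grouped_parameters = [
    {
        # parameters that should receive weight decay
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        # biases and layer norm weights: no weight decay
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]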

A similar problem occurs in trainer.py, which may cause Seq2SeqTrainer to train the T5 layer norms with weight decay:

if self.optimizer is None:
    decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
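
As a quick illustration (a sketch, assuming the public get_parameter_names helper from transformers.trainer_pt_utils), T5's layer norm weights end up in the decay list because they are not instances of nn.LayerNorm:

import torch.nn as nn
from transformers import T5ForConditionalGeneration
from transformers.trainer_pt_utils import get_parameter_names

model = T5ForConditionalGeneration.from_pretrained("t5-small")
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
# the T5 layer norm weights are still in the decay list
print([name for name in decay_parameters if "layer_norm" in name][:3])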

This is because T5 uses T5LayerNorm, not nn.LayerNorm:

class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
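
A quick check (sketch only; T5LayerNorm is imported from the T5 modeling file purely for illustration) confirms that an isinstance test against nn.LayerNorm does not match it:

import torch.nn as nn
from transformers.models.t5.modeling_t5 import T5LayerNorm

layer = T5LayerNorm(hidden_size=512)
print(isinstance(layer, nn.LayerNorm))  # False, so nn.LayerNorm-based filtering misses it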

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented Jul 3, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

Thanks for your PR, I've added a few comments.

Collaborator:

We should keep both here, since the user might use other models.

Collaborator:

This import is not something we really want to add, as the Trainer shouldn't depend on individual model files. Instead, T5LayerNorm should subclass torch.nn.LayerNorm (it rewrites the init and the forward anyway).

Contributor (author):

Thanks for your comments! It seems that no existing layer norm is written as a subclass of nn.LayerNorm. I'm not sure whether this is right:

class T5LayerNorm(nn.LayerNorm):
    def __init__(self, hidden_size, eps):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__(hidden_size, eps, elementwise_affine=False)
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

        self.reset_parameters()

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states

    def reset_parameters(self):
        if self.weight is not None:
            nn.init.ones_(self.weight)

LongT5LayerNorm is copied from T5LayerNorm, so it would need the same change.

Contributor:

This would create unused bias weights for every layer norm in T5 at the moment, which I don't think we want. @sgugger I'm actually not really in favor of adding this abstraction to the T5LayerNorm class.

Collaborator:

Only if the init calls the superclass. If it calls the super of the superclass (nn.Module), nothing breaks. We just can't have the Trainer start depending on multiple modeling files because there are 100 flavors of LayerNorms.

An alternative would be to have a constant in pt_utils to which each submodule that defines a specific LayerNorm adds its own; would that work better?

So here, after defining LongT5LayerNorm, we would have:

ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)
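
For context, a sketch of how the Trainer side could then consume such a registry; the constant name and import path follow this discussion, and the merged code may differ:

from transformers import T5ForConditionalGeneration
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

model = T5ForConditionalGeneration.from_pretrained("t5-small")
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
# T5's layer_norm weights are now excluded from weight decay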

Collaborator:

@ADAning Could you amend your PR to use this solution maybe?

Contributor (author):

Does pt_utils mean transformers.utils?

@ADAning Could you amend your PR to use this solution maybe?

Collaborator:

No, transformers.pytorch_utils. Sorry, I misspelled it.

Contributor (author):

Actually, this is my first time using git... I don't know why other people's commits could be seen after I fetched and rebased from upstream; it really bothered me. So I reset my branch to the commit from when I created this PR, and then this PR was closed and my commit could no longer be seen there. Could you re-open this PR? @sgugger
I'm sorry for my mistake.

@sgugger reopened this Jul 6, 2022
@sgugger (Collaborator) left a comment

Just re-opened your PR! In general you only need to rebase if the PR has been open for a long time (a month) or if there are some critical changes on the main branch that you need, so don't bother for small changes :-)

return self.weight * hidden_states


ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
Collaborator:

This should be after the try/except block below, which may change T5LayerNorm.
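
For context, the ordering matters because modeling_t5.py may conditionally rebind T5LayerNorm in a try/except (e.g. to an optimized fused RMSNorm when apex is available), roughly along these lines (paraphrased sketch, not the exact source):

try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # the name now points at the fused implementation
except ImportError:
    pass  # keep the pure-PyTorch T5LayerNorm defined above

# Appending only after this block registers whichever class is actually in use.
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)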

return self.weight * hidden_states


ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)
Collaborator:

This should be after the try/except block below, which can change LongT5LayerNorm.

@sgugger (Collaborator) commented Jul 6, 2022

Thanks a lot for iterating with us!

@sgugger merged commit bf37e5c into huggingface:main Jul 6, 2022
@ADAning deleted the t5-layer-norm branch July 6, 2022 14:03
@patrickvonplaten (Contributor) left a comment

Thanks a lot for iterating on this and sorry to reply so late here - the final solution looks very nice!

viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
…xample (huggingface#18002)

* Add ALL_LAYERNORM_LAYERS for LayerNorm

* fix bug of appending layer norm
