Conversation

@younesbelkada (Contributor)

What does this PR do?

This PR forces some modules to be initialized in the correct place (i.e. in the _init_weights method).
With more vision models being added, contributors keep copying the practice of initializing some weights outside _init_weights. I think we should centralize weight initialization in the _init_weights method, starting with the most-copied / most-downloaded models.
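For illustration, here is a minimal sketch of the target pattern (the MyEmbeddings / MyPreTrainedModel names are placeholders, not actual transformers classes): parameters are only allocated in __init__, while the random fill lives in _init_weights.

import torch
import torch.nn as nn

class MyEmbeddings(nn.Module):
    def __init__(self, num_positions: int, hidden_size: int):
        super().__init__()
        # only allocate the parameters here, no nn.init.trunc_normal_ call
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, hidden_size))

class MyPreTrainedModel:
    initializer_range = 0.02

    def _init_weights(self, module):
        # centralized initialization, mirroring the ViT diff shown further down
        if isinstance(module, MyEmbeddings):
            nn.init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.initializer_range)
            nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.initializer_range)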

Related:

- initialization on `_init_weights`
- fix copies
@younesbelkada changed the title from "[Vision] Initialize weights on the correct place" to "[Vision] [Refactor] Initialize weights on the correct place" on Dec 16, 2022
@HuggingFaceDocBuilderDev commented Dec 16, 2022

The documentation is not available anymore as the PR was closed or merged.

@NielsRogge (Contributor) left a comment

Thanks for making this cleaner!

@ArthurZucker (Collaborator) left a comment

LGTM, thanks for this! Let's just add a "Copied from" comment for the `__init__` function.

@sgugger (Collaborator) left a comment

Yes, it is way better to do it this way (otherwise using device_map="auto", or more generally _fast_init in from_pretrained, results in those weights not being properly initialized).
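For context, a hedged usage sketch of the loading path mentioned above (it assumes the accelerate package is installed for device_map="auto" and reuses the checkpoint name from this thread):

from transformers import ViTModel

# Loading with device_map="auto" skips the usual random initialization for
# weights found in the checkpoint, so anything missing from the checkpoint
# must be handled correctly by _init_weights.
model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    device_map="auto",
)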

@younesbelkada merged commit ecd7de3 into huggingface:main on Dec 19, 2022
Comment on lines +454 to +465
elif isinstance(module, ViTEmbeddings):
nn.init.trunc_normal_(
module.position_embeddings,
mean=0.0,
std=self.config.initializer_range,
)

nn.init.trunc_normal_(
module.cls_token,
mean=0.0,
std=self.config.initializer_range,
)
@sayakpaul (Member)

@younesbelkada do you mind filling me in a bit about this refactor? I would greatly appreciate it.

@younesbelkada (Contributor, Author) commented Dec 29, 2022

Hey @sayakpaul!

I just moved the lines above into this place. Before this refactor, the attributes position_embeddings & cls_token were initialized on the fly inside ViTEmbeddings, i.e. whenever we created an instance of ViTEmbeddings.

But this approach is not the one we want to follow: it is preferable to centralize the whole weight-initialization process inside the _init_weights method, so that it is only called when we actually need it. Consider this snippet:

from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

I think what we ideally want here is to:
1- Create an instance of ViTModel
2- Load the pre-trained weights directly into it
There is no need to fill the weights with the intended distribution here, since they will be overwritten by the pre-trained weights anyway. Also note that not calling _init_weights speeds up the loading of large models (I believe this is one of the main reasons it is skipped by default in the from_pretrained method).

That is why we always prefer to do it in two stages:
1- initialize each module and its submodules
2- fill the weights of these modules with the correct distribution only if needed
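A rough conceptual sketch of these two stages (only an illustration; load_pretrained and its logic are made up for this example and are not the actual transformers loading code):

def load_pretrained(model, state_dict):
    # stage 1 already happened: model = MyModel(config) created all submodules
    # stage 2a: copy the pre-trained weights that exist in the checkpoint
    missing, _unexpected = model.load_state_dict(state_dict, strict=False)
    # stage 2b: randomly initialize only what the checkpoint did not provide
    for name, module in model.named_modules():
        if any(key == name or key.startswith(name + ".") for key in missing):
            model._init_weights(module)
    return model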
In rare cases you can hit unexpected behaviour when doing everything in stage 1: e.g. before this PR, if you loaded a ViT model with torch_dtype=torch.float16 you would face this error:

RuntimeError: "erfinv_vml_cpu" not implemented for 'Half'

You can reproduce it with this snippet:

import torch
import torch.nn as nn

# mimic what from_pretrained does internally when torch_dtype=torch.float16 is passed
torch.set_default_dtype(torch.float16)

# trunc_normal_ relies on erfinv, which has no float16 kernel on CPU
nn.init.trunc_normal_(
    torch.zeros(1, 1, 2),
    mean=0.0,
    std=0.1,
)

This happens simply because torch.set_default_dtype(torch.float16) is called under the hood when torch_dtype=torch.float16 is passed, and such errors can be very confusing for users!
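After this refactor, the same fp16 loading path should no longer hit that error, since the trunc_normal_ calls only run inside _init_weights (a small sanity-check sketch, assuming the refactored ViT code):

import torch
from transformers import ViTModel

# The pre-trained weights are loaded directly, so no trunc_normal_ call has to
# run on a half-precision tensor during loading.
model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    torch_dtype=torch.float16,
)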

@sayakpaul (Member)

Beauty! Thanks for being so generous with your explanation!

Contributor

@younesbelkada, thank you for tracking the issue.

silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023
…gface#20803)

* fix nit

- initialization on `_init_weights`
- fix copies

* add copied from