[Vision] [Refactor] Initialize weights on the correct place
#20803
Conversation
Title changed: [Vision] Initialize weights on the correct place → [Vision] [Refactor] Initialize weights on the correct place
The documentation is not available anymore as the PR was closed or merged.
NielsRogge
left a comment
Thanks for making this cleaner!
ArthurZucker
left a comment
LGTM, thanks for this, let's just add a copied from for the __init__ function
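(For context on that convention: a # Copied from marker lets transformers' repo-consistency check keep the annotated code in sync with the referenced definition. A minimal sketch, where the class name and target path are only illustrative:)
import torch.nn as nn

class SomeNewModelEmbeddings(nn.Module):
    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.__init__
    def __init__(self, config):
        super().__init__()
        # ... body kept identical to the ViT implementation ...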
sgugger
left a comment
Yes, it's way better to do it this way (otherwise using device_map="auto", or more generally _fast_init in from_pretrained, results in those weights not being properly initialized).
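A minimal illustration of why constructor-time init does not survive that path, assuming (as accelerate does for device_map="auto") that modules are first created on the meta device; the shapes are just placeholders:
import torch
import torch.nn as nn

# A parameter allocated on the "meta" device has a shape but no storage:
# values assigned to it in __init__ never materialize, so only what gets
# filled in later (by _init_weights or the loaded checkpoint) actually exists.
pos_emb = nn.Parameter(torch.zeros(1, 197, 768, device="meta"))
print(pos_emb.is_meta)  # True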
elif isinstance(module, ViTEmbeddings):
    nn.init.trunc_normal_(
        module.position_embeddings,
        mean=0.0,
        std=self.config.initializer_range,
    )
    nn.init.trunc_normal_(
        module.cls_token,
        mean=0.0,
        std=self.config.initializer_range,
    )
@younesbelkada do you mind filling me in a bit about this refactor? I would greatly appreciate it.
hey @sayakpaul!
I just moved the lines above here. Before this refactor, the attributes position_embeddings & cls_token were initialized on the fly inside the constructor, i.e. whenever we create an instance of ViTEmbeddings.
But this is not the approach we want to follow, since it is preferable to centralize the whole weight initialization process inside the _init_weights method, so that it is called only when we actually need it. Consider this snippet:
from transformers import ViTModel
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
I think what we ideally want here is to:
1- create an instance of ViTModel
2- load the pretrained weights directly into it
There is no need to fill the weights with the intended distribution here, since they will be overwritten by the pretrained weights at the end anyway. Also note that not calling _init_weights speeds up loading large models (this is, I believe, one of the main reasons why it is skipped by default in the from_pretrained method).
That is why we always prefer to do it in two stages (see the sketch below):
1- initialize each module and its submodules
2- fill the weights of these modules with the correct distribution, only if needed
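A rough before/after sketch of what this two-stage split looks like (simplified, with made-up class names rather than the actual ViT code):
import torch
import torch.nn as nn

# Before: the distribution is drawn inside the constructor itself, so the
# (dtype-sensitive) init runs on every instantiation, even when the values
# are about to be overwritten by a checkpoint.
class EmbeddingsBefore(nn.Module):
    def __init__(self, hidden_size, initializer_range=0.02):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        nn.init.trunc_normal_(self.cls_token, mean=0.0, std=initializer_range)

# After: the constructor only allocates the parameter; the distribution is
# applied separately by a central _init_weights-style hook, which
# from_pretrained can skip for weights restored from the checkpoint.
class EmbeddingsAfter(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

def init_weights(module, initializer_range=0.02):
    if isinstance(module, EmbeddingsAfter):
        nn.init.trunc_normal_(module.cls_token, mean=0.0, std=initializer_range)

embeddings = EmbeddingsAfter(768)
init_weights(embeddings)  # only needed when training from scratch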
In rare cases you can face unexpected behaviours when doing everything in stage 1-: e.g. before this PR, if you loaded a ViT model with torch_dtype=torch.float16 you would face an error:
RuntimeError: "erfinv_vml_cpu" not implemented for 'Half'
You can reproduce it with this snippet:
import torch
import torch.nn as nn

# this is what from_pretrained does internally when torch_dtype=torch.float16 is passed
torch.set_default_dtype(torch.float16)

# trunc_normal_ relies on erfinv, which has no half-precision CPU kernel,
# hence the RuntimeError above
nn.init.trunc_normal_(
    torch.zeros(1, 1, 2),
    mean=0.0,
    std=0.1,
)
This is simply because torch.set_default_dtype(torch.float16) is called internally when you pass torch_dtype=torch.float16, and such errors can be very confusing for users!
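With the init moved into _init_weights, that call is skipped for weights restored from the checkpoint, so a half-precision load like the one below should no longer hit the error above (a sketch; the checkpoint name is the one from the snippet earlier in this thread):
import torch
from transformers import ViTModel

# The embedding values come from the checkpoint, so the dtype-sensitive
# trunc_normal_ initialization is never executed for them.
model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224-in21k", torch_dtype=torch.float16
)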
Beauty! Thanks for being so generous with your explanation!
@younesbelkada thank you for tracking the issue
…gface#20803)
* fix nit - initialization on `_init_weights` - fix copies
* add copied from
What does this PR do?
This PR forces some modules to be initialised in the correct place (i.e. in the _init_weights method).
With more vision models being added, contributors are copying the practice of initialising some weights outside _init_weights. I think that we should centralize weight initialisation in the _init_weights method, by applying this to the most-copied / downloaded models.
Related: