Update max_diff in test_save_load_fast_init_to_base
#19849
Conversation
    model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

    for key in model_fast_init.state_dict().keys():
        max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
We can't sum the values, as the positive/negative values will cancel out. This is not reliable.
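A minimal, self-contained sketch (not taken from the test itself) of why summing the raw differences is unreliable while the max of the absolute difference is not:

    import torch

    slow = torch.tensor([1.0, -1.0])
    fast = torch.tensor([-1.0, 1.0])

    diff = slow - fast                          # tensor([ 2., -2.])
    print(diff.sum().item())                    # 0.0 -> positive and negative entries cancel
    print(torch.max(torch.abs(diff)).item())    # 2.0 -> the real discrepancy is reported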
tests/test_modeling_common.py (outdated diff)
    for key in model_fast_init.state_dict().keys():
        max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
        if random_key_to_del.endswith(key):
            continue
I think this is reasonable, but need @sgugger to confirm.
No, the test is there to check that the corresponding weight is initialized the same way with slow and fast init. This would make the test less useful.
Alright, I couldn't see why the weights (for that removed key) should end up with the same values under two different initialization methods, since these are random ops and they don't run in the same sequence in the two cases.
And if we don't expect identical values but only close ones (and not just because of numerical issues), it's not very clear to me how the threshold is determined.
(But I didn't really go through the whole logic around this part...)
I might be wrong about what the test is doing, but I thought it was checking that. Though it would require setting the seed before both loads to work. @patrickvonplaten could you shed some light here since you wrote it?
Setting the seed is a necessary but not sufficient condition in this case. Since we change the way of initialization, the position (in the sequence of random operations performed by torch) at which a particular weight is initialized (here, the one for the deleted key) will change, I believe.
If we want to obtain equal weight values with both init methods, we would have to set the same seed just before initializing the target weight(s). However, I feel this is not really tractable and not worth the effort.
The naming might not have been spot-on here. In short, the mechanism and rationale behind this test is to check that the initialization works not only when initializing whole parts of the model (e.g. the model head, or the whole cross-attention layer), but also when the checkpoint is missing a single weight / parameter. It's an edge case test.
E.g. what we want to test here is the following scenario: The bias of the first feed-forward layer of the 5th block of BERT is removed from the state-dict, then this single bias parameter will be initialized when loading the state dict with from_pretrained(...) -> now this initialization should be the same for both the fast and slow init method. IMO we should definitely keep this line because removing it makes the test worse.
Does that make sense?
This should be unrelated to the seed because, if I remember correctly, _init_weights is overwritten by a deterministic self._mock_init_weights function, so there is no randomness in play here.
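For illustration, a hypothetical deterministic mock init along the lines Patrick describes (the actual _mock_init_weights in test_modeling_common.py may differ); it shows why two initializations agree without any seeding:

    import torch
    import torch.nn as nn

    def mock_init_weights(module: nn.Module) -> None:
        # Fill every parameter of this module with a constant, so any two
        # initializations agree exactly, regardless of the RNG state.
        for _, param in module.named_parameters(recurse=False):
            with torch.no_grad():
                param.fill_(3.0)

    layer_a, layer_b = nn.Linear(4, 4), nn.Linear(4, 4)
    mock_init_weights(layer_a)
    mock_init_weights(layer_b)
    print(torch.equal(layer_a.weight, layer_b.weight))  # True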
Thanks @patrickvonplaten !
I think the flakiness we see in some models then comes from the fact that some of the weights are not initialized in the _init_weights function. For instance ViT does not initialize its embeddings in _init_weights.
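A hedged sketch of this point (the ViT specifics are simplified away): if _init_weights only covers certain module types, any parameter it skips keeps the random value from its constructor, so two separate initializations will generally disagree on it:

    import torch
    import torch.nn as nn

    def init_weights(module: nn.Module) -> None:
        # Only nn.Linear is handled here; the bare nn.Parameter below is skipped.
        if isinstance(module, nn.Linear):
            with torch.no_grad():
                module.weight.fill_(0.02)
                module.bias.zero_()

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(4, 4)
            self.pos_embed = nn.Parameter(torch.randn(1, 4))  # never touched by init_weights

    model_a, model_b = TinyModel(), TinyModel()
    model_a.apply(init_weights)
    model_b.apply(init_weights)
    print(torch.equal(model_a.dense.weight, model_b.dense.weight))  # True: deterministic
    print(torch.equal(model_a.pos_embed, model_b.pos_embed))        # False (almost surely): left random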
            continue
        max_diff = torch.max(
            torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
        ).item()
Using abs, and taking the max instead of the sum.
The documentation is not available anymore as the PR was closed or merged.
Thank you @patrickvonplaten! Although I still have some doubt about the assumptions we have and what we do here. But good for me if we don't want to touch it. We probably need to add some comment about the flakiness of some tests though.
@sgugger Could you check if the change in the way max_diff is computed looks good to you?
sgugger left a comment:
LGTM, thanks!
What does this PR do?
This test seems flaky. After looking a bit deeper, I am not sure if we should expect to get the same (or very close) weights with/without _fast_init for the deleted key.

transformers/tests/test_modeling_common.py, line 343 in 9ecb13d
transformers/tests/test_modeling_common.py, line 400 in 9ecb13d

My intuition is that the values for that deleted key could be different with the two different init methods.
I also change the way max_diff is calculated.
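For reference, a minimal before/after sketch of the max_diff change this PR makes (the tensors are illustrative stand-ins for the two state-dict entries):

    import torch

    slow_weight = torch.tensor([0.5, -0.5])
    fast_weight = torch.tensor([-0.5, 0.5])

    # Before: sum of signed differences, which can cancel out to zero.
    old_max_diff = (slow_weight - fast_weight).sum().item()                # 0.0

    # After: maximum absolute difference, which cannot cancel.
    new_max_diff = torch.max(torch.abs(slow_weight - fast_weight)).item()  # 1.0

    print(old_max_diff, new_max_diff)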