Conversation

Collaborator

@ydshieh ydshieh commented Oct 24, 2022

What does this PR do?

This test seems flaky. After looking a bit deeper, I am not sure we should expect to get the same (or very close) weights with and without _fast_init for the deleted key.

random_key_to_del = random.choice(list(state_dict.keys()))

for key in model_fast_init.state_dict().keys():

My intuition is that the values for that deleted key could differ between the two initialization methods.

I also change the way max_diff is calculated.

model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

for key in model_fast_init.state_dict().keys():
    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
Collaborator Author

@ydshieh ydshieh Oct 24, 2022

We can't simply sum the differences, as the positive and negative values will cancel out. This is not reliable.
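For illustration only (my own snippet, not part of the PR): a tiny standalone PyTorch example of how a signed sum of differences can report zero for tensors that differ in every element, while the max of the absolute difference does not.

import torch

# Two tensors that disagree in every element.
a = torch.tensor([1.0, -1.0, 2.0, -2.0])
b = torch.tensor([-1.0, 1.0, -2.0, 2.0])

diff = a - b
print(diff.sum().item())                   # 0.0 - positive and negative differences cancel out
print(torch.max(torch.abs(diff)).item())   # 4.0 - the actual largest deviation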

for key in model_fast_init.state_dict().keys():
    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
    if random_key_to_del.endswith(key):
        continue
Collaborator Author

I think this is reasonable, but I need @sgugger to confirm.

Collaborator

No, the test is there to check that the corresponding weight is initialized the same way with slow and fast init. This would make the test less useful.

Collaborator Author

@ydshieh ydshieh Oct 24, 2022

Alright, I can't see why the weights for that removed key should have the same values under two different initialization methods, since these are random ops and they don't run in the same sequence in the two cases.

And if we don't expect the same values but only close ones (and not just because of numerical issues), it's not clear to me how the threshold is determined.

(But I didn't really go through the whole logic around this part ...)

Collaborator

I might be wrong about what the test is doing, but I thought it was checking exactly that. It would require setting the seed before both loads to work, though. @patrickvonplaten could you shed some light here, since you wrote it?

Collaborator Author

Setting the seed is a necessary but not sufficient condition here. Changing the way of initialization also changes the position, in the sequence of random operations performed by torch, at which a particular weight (here, the one for the deleted key) gets initialized - I believe.

If we want to obtain equal weight values with both init methods, we would have to set the same seed right before initializing the target weight(s). However, I feel this is not really tractable and not worth the effort.
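A minimal sketch (my own illustration, not code from the repo) of the point about the position in the sequence of random ops: with the same global seed, a tensor that is drawn at a different position in the sequence ends up with different values.

import torch

torch.manual_seed(0)
target_a = torch.empty(3).normal_()     # the "target" weight is drawn first

torch.manual_seed(0)
_ = torch.empty(5).normal_()            # this time another weight is initialized first
target_b = torch.empty(3).normal_()     # the same "target" weight, drawn later in the sequence

print(torch.equal(target_a, target_b))  # False - same seed, different position in the RNG stream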

Contributor

The naming might not have been spot-on here. In short, the mechanism and the rationale behind this test is to check that the initialization works not only for whole parts of the model (e.g. the model head, or the whole cross-attention layer), but also when the checkpoint is missing a single weight / parameter. It's an edge-case test.

E.g. what we want to test here is the following scenario: the bias of the first feed-forward layer of the 5th block of BERT is removed from the state dict; this single bias parameter will then be initialized when loading the state dict with from_pretrained(...) -> now this initialization should be the same for both the fast and the slow init method. IMO we should definitely keep this line, because removing it makes the test weaker.

Does that make sense?
This should be unrelated to the seed because, if I remember correctly, _init_weights is overwritten by a deterministic self._mock_init_weights function, so there is no randomness in play here.
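A rough, self-contained sketch of the scenario described above (my own illustration with BertModel and a tiny config, not the actual test from the transformers test suite, and assuming a 2022-era transformers where save_pretrained writes pytorch_model.bin and from_pretrained accepts the private _fast_init flag):

import os
import random
import tempfile

import torch
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertModel(config)

# Drop one randomly chosen weight from the checkpoint, as the test does.
state_dict = model.state_dict()
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    # Overwrite the saved weights with the pruned state dict so exactly one key is missing.
    torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

    # Both loading paths have to re-initialize the missing key.
    model_fast_init = BertModel.from_pretrained(tmpdirname)                    # _fast_init=True is the default
    model_slow_init = BertModel.from_pretrained(tmpdirname, _fast_init=False)

    for key in model_fast_init.state_dict().keys():
        max_diff = torch.max(
            torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
        ).item()
        # Keys loaded from the checkpoint should show ~0; the deleted key is re-initialized
        # and, without a deterministic init override, may differ between the two paths.
        print(key, max_diff)

In the actual test, _init_weights is replaced by the deterministic _mock_init_weights mentioned above, which is what makes the re-initialized key comparable across the two paths.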

Collaborator

Thanks @patrickvonplaten!

I think the flakiness we see in some models then comes from the fact that some of the weights are not initialized in the _init_weights function. For instance, ViT does not initialize its embeddings in _init_weights.
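To make that concrete, a toy stand-in (my own module, not actual ViT code): a parameter set directly in __init__ is untouched by _init_weights, so even a deterministic _init_weights override cannot make it match across two independent initializations.

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialized directly in __init__, analogous to embeddings that _init_weights skips.
        self.pos_embed = nn.Parameter(torch.randn(1, 4, 8))
        # Covered by _init_weights below.
        self.dense = nn.Linear(8, 8)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.fill_(3)
            if module.bias is not None:
                module.bias.data.fill_(3)

a, b = TinyModel(), TinyModel()
for m in (a, b):
    for module in m.modules():
        m._init_weights(module)

print(torch.equal(a.dense.weight, b.dense.weight))  # True  - covered by the deterministic init
print(torch.equal(a.pos_embed, b.pos_embed))        # False - set in __init__, hence the flakiness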

    continue
max_diff = torch.max(
    torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item()
Collaborator Author

@ydshieh ydshieh Oct 24, 2022

Using abs, then taking the max instead of the sum.


HuggingFaceDocBuilderDev commented Oct 24, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh requested a review from sgugger October 24, 2022 18:17
Collaborator Author

ydshieh commented Oct 26, 2022

Thank you @patrickvonplaten! Although I still have some doubts:

The assumption we have:

now this initialization should be the same for both the fast and slow init method

What we do:

_init_weights is overwritten by a deterministic self._mock_init_weights

But _mock_init_weights is only defined in the testing module, and it basically just does data.fill_(3).
So the assumption only holds in our own tests (which use _mock_init_weights); it won't hold when loading the model outside the test suite. So I am not very sure about the purpose of this test.
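Roughly what such a mock looks like (a sketch based on the description above, not the exact helper from the test suite):

def _mock_init_weights(self, module):
    # Deterministic stand-in for the model's real _init_weights: every parameter it
    # touches is filled with a constant, so fast and slow init agree by construction.
    if hasattr(module, "weight") and module.weight is not None:
        module.weight.data.fill_(3)
    if hasattr(module, "bias") and module.bias is not None:
        module.bias.data.fill_(3)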

But it's fine for me if we don't want to touch it. We probably need to add a comment about the flakiness to some tests, though.

Collaborator Author

ydshieh commented Oct 26, 2022

@sgugger Could you check whether the change in the way max_diff is calculated is worth merging 🙏? Thanks

Collaborator

@sgugger sgugger left a comment

LGTM, thanks!

@ydshieh ydshieh merged commit 688c3e8 into main Oct 26, 2022
@ydshieh ydshieh deleted the fix_test_save_load_fast_init_to_base branch October 26, 2022 15:09
@ydshieh ydshieh changed the title from Fix test_save_load_fast_init_to_base to Update max_diff in test_save_load_fast_init_to_base Oct 26, 2022