Update past_key_values in GPT-2
#9596
Conversation
The CircleCI error messages say as below.

Is there a difference between ...? I first thought it might be a difference between the Causal language model and the Seq2Seq language model, but it seems that both ... And as for the contents of `src/transformers/models/bart/modeling_bart.py` (lines 1236 to 1244 in 236cc36).
I've updated `src/transformers/models/xlnet/modeling_xlnet.py` (lines 581 to 607 in 236cc36). It seems ...
Hey @forest1988, your PR looks very nice! Yes, it is expected that ...

```python
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(...)
```
I've just updated ...
This way it's much cleaner and correct :-) The reason I'm proposing this change is that the ...

```python
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}")
```
I think this should solve the problems, let me know if you need more help :-)
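For illustration, a model-specific `_reorder_cache` for a cache made of per-layer `(key, value)` tuples could follow the pattern below. This is a minimal sketch of the general idea; the function name, type hints, and docstring are assumptions, not the exact code merged in this PR:

```python
from typing import Tuple

import torch


def reorder_cache(past: Tuple[Tuple[torch.Tensor, ...], ...], beam_idx: torch.Tensor):
    """Reorder every cached tensor along the batch dimension so that each layer's
    (key, value) pair follows the beams selected at the current generation step."""
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past
    )
```

In `transformers` this logic lives as a method on the model class, so that `generate()` can call it whenever beam search reorders hypotheses.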
Thank you for your advice! I'll update ...
Thanks to your kind advice, I could solve the problem of ... The last remaining bug is: ... I think I should modify ...
All checks have passed! However, in the documentation of `past_key_values` in GPT-2 ...
```
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.

For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
```
remove those lines and past_key_values above
I cleaned it as well.
patrickvonplaten left a comment
The PR looks very nice - thanks so much for taking the time to tackle this @forest1988. Let's wait a bit to see how to proceed with `gradient_checkpointing` in GPT2, as this question will come up more often. IMO, `use_cache` should always be `False` for training, so either we update all `use_cache` in the models with a `use_cache = not self.is_training and (use_cache if use_cache is not None else self.config.use_cache)` or we force it somehow in the `Trainer`. Similarly, `gradient_checkpointing` should never be set to `True` when the model is not training IMO (we could also automatically disable this using `self.training`). Let's see what @LysandreJik and @sgugger think.
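To make the interaction concrete, here is a minimal sketch of the kind of guard being discussed, written as a standalone helper; the function name and the exact warning text are assumptions, not the actual transformers code:

```python
import logging

logger = logging.getLogger(__name__)


def resolve_use_cache(use_cache: bool, gradient_checkpointing: bool, training: bool) -> bool:
    """Return the effective use_cache flag: caching is incompatible with gradient
    checkpointing, so it is forced off while training with checkpointing enabled."""
    if gradient_checkpointing and training and use_cache:
        logger.warning("`use_cache=True` is incompatible with gradient checkpointing; disabling the cache.")
        return False
    return use_cache
```

Inside a model's `forward`, the same check would typically read the module's own flags (e.g. `self.training`) instead of taking explicit arguments.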
sgugger left a comment
This is not a part of the library I'm very familiar with, so the changes look okay on my side, but I'm no expert.
LysandreJik left a comment
These changes look good to me! Thanks for taking care of it @forest1988.
patrickvonplaten left a comment
Great work @forest1988,
I hope it's fine for you that I went into the PR to do some final fixes. Thanks a lot for cleaning this up :-)
Of course! Thank you for adding fixes to make this PR more valuable!
LysandreJik left a comment
Your commit looks good to me @patrickvonplaten! Thanks.
sgugger left a comment
The new changes look good to me, thanks!
Awesome, merging - great job @forest1988!

Thank you for your advice and encouraging comments!
```diff
 if use_cache is True:
-    present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+    present = (key.transpose(-2, -1), value)  # transpose to have same shapes
```
This is the reason for the recent failure of the slow test:
`RUN_SLOW=1 pytest tests/test_onnx.py::OnnxExportTestCase::test_export_pytorch`
Can you fix the onnx part easily? @mfuntowicz @Narsil
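For context, the change above switches each layer's cached `present` from a single stacked tensor to a pair of tensors. A small sketch of the two formats; tensor shapes and variable names here are illustrative assumptions, not the exact GPT-2 internals:

```python
import torch

batch, heads, seq_len, head_dim = 2, 12, 5, 64
key = torch.randn(batch, heads, head_dim, seq_len)    # stored transposed for the attention matmul
value = torch.randn(batch, heads, seq_len, head_dim)

# Old per-layer cache entry: one tensor of shape (2, batch, heads, seq_len, head_dim)
old_present = torch.stack((key.transpose(-2, -1), value))

# New per-layer cache entry: a tuple of two tensors, matching Bart's style
new_present = (key.transpose(-2, -1), value)

assert old_present[0].shape == new_present[0].shape
assert old_present[1].shape == new_present[1].shape
```

Exporters that expect a single tensor per layer (such as the ONNX export mentioned above) need to be adapted to the tuple format.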
What does this PR do?
It seems GPT-2 and BartDecoder have different styles of `past_key_values`. Advised by @patrickvonplaten, I opened this PR to change GPT-2's cache format from a single tensor to a tuple of 2 tensors. Once this problem is solved, it is expected that `past_key_values` in GPT-2 will be handled in the same way as in Bart.

Sorry, there remain some errors. This PR is [WIP]. I would appreciate your advice on how to update `generation_utils.py`. Can I modify `_reorder_cache` so that `past` is changed from `Tuple[torch.Tensor]` to `Tuple[Tuple[torch.Tensor]]`, or should I consider other output variations, `output.mem` and `outputs.past_buckets_states`?
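As a usage sketch of the cache format this PR moves toward, incremental decoding with GPT-2 might look roughly like this; the model name, loop length, and prompt are arbitrary, and this assumes a transformers version where GPT-2 returns per-layer `(key, value)` tuples:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer("Hello, my dog", return_tensors="pt").input_ids
past_key_values = None  # tuple of per-layer (key, value) pairs once populated

for _ in range(5):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = next_token  # only the new token is fed when the cache is passed back in
```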
Fixes #9391
From patrickvonplaten:
This PR cleans up the `_reorder_cache` logic. Now `_reorder_cache` defaults to raising a `NotImplementedError` in `generation_utils.py`, forcing the model to implement its corresponding `_reorder_cache` in the `modeling_...py` file itself. This is cleaner, as `_reorder_cache` strongly differs from model to model. In addition, this PR makes sure that `gradient_checkpointing` can only be used if the model is in training mode, and makes sure that `use_cache` is disabled when training with `gradient_checkpointing` enabled to prevent errors.
Who can review?
GPT2: @LysandreJik, @patrickvonplaten