Replace past with past_key_values
#20944
sgugger
left a comment
Thanks for working on this! There seems to be an issue with the TensorFlow tests. I would also like to have @gante's opinion on this before merging.
gante
left a comment
LGTM 👍
Although the change in the failing TF test is weird, we should try to understand it before we merge this PR 🤔
@staticmethod
def _reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        # cached cross_attention states don't have to be reordered -> they are always the same
        reordered_past += (
            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past
This is no longer needed on the TF side :D (It was used in the code path that existed before the XLA transition)
Sure, I thought that the failing test might come from this but I was wrong 😉
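For readers unfamiliar with the helper discussed above: it reorders each layer's first two cached states along the first axis so they follow the beams selected at the current step, and leaves the remaining (cross-attention) states untouched. A minimal pure-Python sketch of that idea, assuming nested lists in place of tensors; `gather_rows` is a hypothetical stand-in for `tf.gather(..., axis=0)`, not a library function:

```python
def gather_rows(state, beam_idx):
    """Hypothetical stand-in for tf.gather(state, beam_idx, axis=0) on a list of rows."""
    return [state[i] for i in beam_idx]


def reorder_cache(past, beam_idx):
    """Reorder the first two states of every layer; keep the remaining states as they are."""
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(gather_rows(state, beam_idx) for state in layer_past[:2])
            + tuple(layer_past[2:]),
        )
    return reordered_past


# One layer holding: self-attention key, self-attention value, cross-attention key.
past = ((["k0", "k1", "k2"], ["v0", "v1", "v2"], ["c0", "c1", "c2"]),)

# Beam 2 survives once and beam 0 survives twice; only the first two states follow.
reordered = reorder_cache(past, beam_idx=[2, 0, 0])
```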
…to remove_past_for_past_key
Ok the failing tests were because I did not pull from main, were the
Yes, good to merge for me!
What does this PR do?
The argument past was completely replaced with past_key_values, thus this PR should fix any problem with kwargs being swallowed for old models in generation.
Related to #20347
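To illustrate what "kwargs being swallowed" can look like, here is a small sketch. It is not the transformers code: both functions and their return values are hypothetical, and they only assume a signature that takes past_key_values plus a catch-all **kwargs.

```python
def forward_swallowing(input_ids, past_key_values=None, **kwargs):
    """Hypothetical signature: a legacy `past=...` lands in kwargs and is never read."""
    return {"input_ids": input_ids, "cache": past_key_values}


def forward_strict(input_ids, past_key_values=None, **kwargs):
    """Hypothetical variant that refuses the legacy name instead of ignoring it."""
    if "past" in kwargs:
        raise TypeError("`past` was replaced by `past_key_values`")
    return {"input_ids": input_ids, "cache": past_key_values}


# An old caller still using `past` silently loses its cache.
lost = forward_swallowing([1, 2, 3], past="cached-states")

# A caller using the new name gets the cache through.
kept = forward_strict([1, 2, 3], past_key_values="cached-states")
```

With a single argument name used everywhere, the silent case in the first call cannot arise in the first place.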