Add unpack_inputs decorator for ctrl #16242
Conversation
The documentation is not available anymore as the PR was closed or merged.
gante left a comment
Thank you for the contribution 🔥 Have a look at my comment, as it is an important thing for us at Hugging Face.
Other than that, ready to merge 🚀
self,
input_ids=None,
past=None,
past_key_values=None,
This change is well-intentioned and technically correct, but I'm going to ask you to revert past_key_values to past. It changes the public interface of the model, which may disrupt downstream users :)
I see and understand. The only problem is that without past_key_values in the parameters, the tests will fail. Can they both (past and past_key_values) be in the parameter list?
And what about the other places where I changed it, like line 570, should that be reverted too?
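As a purely hypothetical sketch of the "keep both names" idea raised here (not what the merged PR ends up doing, which renames past to past_key_values outright), a model could accept both keywords and fold the deprecated one into the new one:

```python
# Hypothetical sketch only: accept both `past` and `past_key_values` in the
# signature and map the old name onto the new one with a deprecation warning.
# Not the approach the merged PR takes.
import warnings


def call(input_ids=None, past_key_values=None, past=None, **kwargs):
    if past is not None:
        warnings.warn(
            "The `past` argument is deprecated, use `past_key_values` instead.",
            FutureWarning,
        )
        # Only fall back to `past` when `past_key_values` was not given explicitly.
        past_key_values = past_key_values if past_key_values is not None else past
    return input_ids, past_key_values
```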
Oh yeah, my comment applies to all instances of past that got replaced with past_key_values. Does the problem persist after replacing them all?
Yes, the problem arose when I first replaced all the input_processing-related code.
It throws a ValueError from the input_processing method (line 436 in modeling_tf_utils.py), because past_key_values remains in kwargs_call. That is because past is not among the kwargs passed to the input_processing function (I did some debugging there).
From what I could reconstruct, past_key_values ends up in kwargs_call in the run_call_with_unpacked_inputs method because it is not initially in the signature of the functions, so my first idea for a fix was to put it there.
Then I went a bit overboard and replaced all the past variables ;)
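To make the failure mode concrete, here is a simplified, hypothetical sketch of the mechanism described above. The helper below is illustrative only, not the actual input_processing implementation in modeling_tf_utils.py: any keyword argument that the wrapped call signature does not declare ends up in a leftover kwargs_call-style dict and gets rejected.

```python
# Simplified illustration of the reported ValueError, under the assumption that
# leftover kwargs not present in the call signature are collected and rejected.
# This is NOT the real input_processing code, just a sketch of the mechanism.
import inspect


def process_inputs(func, **kwargs):
    known = set(inspect.signature(func).parameters)
    # Anything the signature does not declare is left over.
    kwargs_call = {k: v for k, v in kwargs.items() if k not in known}
    if kwargs_call:
        # The real method raises a similar error when unexpected kwargs remain.
        raise ValueError(f"Got unexpected keyword arguments: {list(kwargs_call)}")
    return {k: v for k, v in kwargs.items() if k in known}


def call(input_ids=None, past=None):  # signature still declares `past`
    return input_ids, past


# generate() now passes `past_key_values`, which the signature does not declare:
process_inputs(call, input_ids=[[1, 2]], past_key_values=None)  # raises ValueError
```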
Okay, I see what's going on! It's actually related to a recent refactor we are doing on the TF side -- bringing our TF generate() up to speed with FLAX/PyTorch. In that PR, in the prepare_inputs_for_generation() function, the output dictionary key was updated from past to past_key_values. The output of that function is then fed to the model, explaining the issue you see. This is a great example of the problem of changing interfaces.
The new planned interface for generate() does rely on past_key_values, not on past, although most models don't use it as an explicit keyword argument. Normally, some sort of deprecation warning should be added, but since this argument is mostly for internal use (through the public generate()), there should be no need. I will take responsibility if a few users complain :)
Thank you so much for having the patience to explain the rationale behind your change, it helped me understand the issue faster 💪
(Although now I'm curious -- how did the model not break before? 🤔 )
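As a simplified illustration of the mismatch described in this comment (hypothetical names, not the real transformers code): generation builds a dict of model inputs, and renaming one of its keys breaks any model whose call signature still expects the old name.

```python
# Minimal, hypothetical illustration of the key-rename mismatch; not the actual
# transformers implementation.
def prepare_inputs_for_generation(input_ids, cache=None):
    # After the refactor, the cache is returned under "past_key_values" ...
    return {"input_ids": input_ids, "past_key_values": cache}


def model_call(input_ids=None, past=None):  # ... but the model still expects `past`
    return input_ids, past


inputs = prepare_inputs_for_generation([[1, 2, 3]])
model_call(**inputs)  # TypeError: unexpected keyword argument 'past_key_values'
```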
Yeah, I also wondered why the tests didn't break before 🤷‍♂️
I've added a todo for me to ensure we add a new test :)
Going to run the slow tests locally, to confirm they pass
Can confirm that they pass, merging 👍
* add unpack_inputs decorator for ctrl
* replace "past" with "past_key_values"
Co-authored-by: Johannes Kolbe <johannes.kolbe@tech.better.team>
What does this PR do?
Add the unpack_inputs decorator for ctrl. It also replaces the past parameter in model input and output with past_key_values, as there was an irregularity in the naming that caused an error with the new input processing.
Fixes #16051
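For context, here is a rough sketch of what a decorator in the spirit of unpack_inputs does, assuming its role is to normalize positional and keyword inputs into explicit keyword arguments before call() runs. This is illustrative only, not the actual transformers implementation.

```python
# Hypothetical sketch of an unpack_inputs-style decorator: bind whatever the
# caller passed against the real signature and forward everything as explicit
# keyword arguments with defaults filled in. Not the real transformers code.
import functools
import inspect


def unpack_inputs_sketch(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        bound.apply_defaults()
        return func(**bound.arguments)

    return wrapper


class TinyModel:
    @unpack_inputs_sketch
    def call(self, input_ids=None, past_key_values=None, training=False):
        return input_ids, past_key_values, training


print(TinyModel().call([[1, 2]], past_key_values=None))
# -> ([[1, 2]], None, False)
```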
Who can review?
@gante
Tests
I ran RUN_SLOW=1 py.test -vv tests/ctrl/test_modeling_tf_ctrl.py, but it only got to around 69% and failed the pre-trained model test, because the model is too big for my local machine to test.