-
Notifications
You must be signed in to change notification settings - Fork 31.6k
[Whisper] Fix decoder ids methods #20599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Whisper] Fix decoder ids methods #20599
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
sgugger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
ArthurZucker
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick fix
| self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps) | ||
| forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)] | ||
| return forced_decoder_ids |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch! I should have realized when reviewing!
| msg="`processor` and `feature_extractor` model input names do not match", | ||
| ) | ||
|
|
||
| def test_get_decoder_prompt_ids(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
* [Whisper] Fix decoder ids methods * enum property
What does this PR do?
The previous PR #20589 incorrectly returned a list of forced decoder ids:
Print Output:
The correct format is a nested list of decoder ids, where the first element of each list specifies the position of the forced token and the second the token id:
Print Output:
(at position 1 we force token 50257, at 2 we force 50358, at 3 we force 50362)
The PR also implements a test, thus making sure that no such error can be made again 😅
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.