-
Notifications
You must be signed in to change notification settings - Fork 6.7k
adding search.PrefixConstrainedBeamSearch #2646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding search.PrefixConstrainedBeamSearch #2646
Conversation
|
The test failed on something that is not part of the pull request |
myleott
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can ignore the test_translation_multi_simple_epoch test failure (the psutil import failure has been fixed in trunk).
But the test_ensemble_sequence_generator (tests.test_sequence_generator.TestJitSequeneceGenerator) failures seems related (see comment below)
| if num_remaining_sent == 0: | ||
| break | ||
| if isinstance(self.search, search.PrefixConstrainedBeamSearch) and step >= max_len: | ||
| if self.search.stop_on_max_len and step >= max_len: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@myleott has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@myleott has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I am not a Facebook employee so I cannot see the warnings and why this fails. |
I'm taking care of this :) |
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? This adds a new decoding strategy `search.PrefixConstrainedBeamSearch` that limits the vocabulary of the next token generation given a prefix (that is the previously generated tokens during beam search). An end user has just to give the optional argument `prefix_allowed_tokens_fn` to `.generate` or `.sample` to activate `PrefixConstrainedBeamSearch`. `prefix_allowed_tokens_fn(batch_id, tokens)` is a callback function that given the `batch_id` and `tokens` returns the list of allowed token for the next generation step. ## Did you have fun? YES! � Pull Request resolved: facebookresearch/fairseq#2646 Reviewed By: fabiopetroni Differential Revision: D24006805 Pulled By: myleott fbshipit-source-id: 40b1a866c6ea9f936272db27e2a020b18dbf8164
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? This adds a new decoding strategy `search.PrefixConstrainedBeamSearch` that limits the vocabulary of the next token generation given a prefix (that is the previously generated tokens during beam search). An end user has just to give the optional argument `prefix_allowed_tokens_fn` to `.generate` or `.sample` to activate `PrefixConstrainedBeamSearch`. `prefix_allowed_tokens_fn(batch_id, tokens)` is a callback function that given the `batch_id` and `tokens` returns the list of allowed token for the next generation step. ## Did you have fun? YES! � Pull Request resolved: facebookresearch/fairseq#2646 Reviewed By: fabiopetroni Differential Revision: D24006805 Pulled By: myleott fbshipit-source-id: 40b1a866c6ea9f936272db27e2a020b18dbf8164
Before submitting
What does this PR do?
This adds a new decoding strategy
search.PrefixConstrainedBeamSearchthat limits the vocabulary of the next token generation given a prefix (that is the previously generated tokens during beam search). An end user has just to give the optional argumentprefix_allowed_tokens_fnto.generateor.sampleto activatePrefixConstrainedBeamSearch.prefix_allowed_tokens_fn(batch_id, tokens)is a callback function that given thebatch_idandtokensreturns the list of allowed token for the next generation step.Did you have fun?
YES! 🙃