-
Notifications
You must be signed in to change notification settings - Fork 31.8k
🚨🚨 Fix group beam search #24407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🚨🚨 Fix group beam search #24407
Conversation
| # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch. | ||
| # If group_beam_search is not used, the list consists of `batch_size` beam_hyps. | ||
| self._beam_hyps = [ | ||
| BeamHypotheses( | ||
| num_beams=self.num_beams, | ||
| num_beams=self.group_size, | ||
| length_penalty=self.length_penalty, | ||
| early_stopping=self.do_early_stopping, | ||
| max_length=max_length, | ||
| ) | ||
| for _ in range(batch_size) | ||
| for _ in range(batch_size * self.num_beam_groups) | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is the most important, and the rest of the changes are for consistency.
|
The documentation is not available anymore as the PR was closed or merged. |
gante
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the clean PR, @hukuda222 🙌
I have 3 further requests before I tag a core maintainer to greenlight this PR:
1 - rebase with main -- I suspect it is the cause for the failure in CI (lmk if you need instructions)
2 - run RUN_SLOW=1 py.test tests/generation/test_utils.py -vv -- I think the outputs in the integration test for group beam search will need to be updated due to these corrections. These tests may take a few minutes to run, depending on your machine :)
3 - Add "🚨🚨" to the PR title. We use it to flag output changes in our methods, so we don't forget to communicate about it in the next release 🤗
Summary of the problem and corresponding fix (for the core maintainer and our future selves)ProblemThe generation loop in FixTwo different paths were possible: a) add logic to |
Co-authored-by: Joao Gante <[email protected]>
|
@gante |
|
@sgugger this comment summarizes the problem and the fix |
sgugger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR. I'm not clear on what the breaking changes are here (since there are 🚨 in the PR description). It should be put very visibly in the first comment with the way to enable the previous behavior if there are any.
|
@sgugger the breaking changes here in the generated outputs from Since it is a bug fix, there is no need to ensure retro compatibility, correct? |
sgugger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not for a bug fix indeed. Thanks for clarifying!
What does this PR do?
Diverse beam search is a method that generates
num_beams//num_beam_groupssentences for each group independently. However, the current code uses one BeamHypotheses shared by all groups. Therefore, group A will generate two sentences before group B outputs a sentence. So, I created BeamHypotheses for each group so that inferences can be made independently.Changes are as follows.
inference code:
before:
after:
Fixes #24369
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@gante