@hukuda222 (Contributor) commented Jun 21, 2023

What does this PR do?

Diverse beam search is a method that generates num_beams//num_beam_groups sentences for each group independently (e.g., with num_beams=2 and num_beam_groups=2, each group should contribute one sentence). However, the current code uses a single BeamHypotheses shared by all groups, so one group can fill all the final slots before another group outputs any sentence. This PR creates a separate BeamHypotheses per group so that each group is finalized independently.

Changes are as follows.

inference code:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-xsum")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-xsum")
text = "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration."
outputs = model.generate(
    tokenizer.encode(text, return_tensors="pt", max_length=512),
    num_beam_groups=2,  # two groups, so each group runs one beam
    num_beams=2,
    diversity_penalty=1000000.0,  # very large penalty to force the groups apart
    num_return_sequences=2,
)
print("\n".join(tokenizer.batch_decode(outputs, skip_special_tokens=True)))

before:

A number Of research projects have investigated the role of the brain's encoder and decoder in the control of the encoded sequences.
A number Of research projects have investigated the role of the brain's encoder and decoder in the control of the encoded sequences..

after:

The study of the activity of the brain's encoders and decoders has revealed a range of different models of how the brain processes information.
A number Of research projects have investigated the role of the brain's encoder and decoder in the control of the encoded sequences.

Fixes #24369

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@hukuda222 marked this pull request as ready for review June 22, 2023 01:26
Comment on lines +182 to 192
+ # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
+ # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
  self._beam_hyps = [
      BeamHypotheses(
-         num_beams=self.num_beams,
+         num_beams=self.group_size,
          length_penalty=self.length_penalty,
          early_stopping=self.do_early_stopping,
          max_length=max_length,
      )
-     for _ in range(batch_size)
+     for _ in range(batch_size * self.num_beam_groups)
  ]
@hukuda222 (Contributor Author):

This change is the most important, and the rest of the changes are for consistency.
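
For illustration, here is a minimal sketch of the indexing convention this change introduces (the toy sizes are assumptions for the sketch, not values from the PR):

# One BeamHypotheses per (batch item, group) pair, laid out row-major:
#   index = batch_idx * num_beam_groups + group_idx
batch_size = 2
num_beam_groups = 2
beam_hyps = [
    f"hyps(batch={i}, group={j})"
    for i in range(batch_size)
    for j in range(num_beam_groups)
]
assert beam_hyps[1 * num_beam_groups + 0] == "hyps(batch=1, group=0)"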

@HuggingFaceDocBuilderDev commented Jun 22, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante (Contributor) left a comment

Thank you for the clean PR, @hukuda222 🙌

I have three further requests before I tag a core maintainer to greenlight this PR:
1 - Rebase with main -- I suspect it is the cause of the CI failure (lmk if you need instructions).
2 - Run RUN_SLOW=1 py.test tests/generation/test_utils.py -vv -- I think the outputs in the integration test for group beam search will need to be updated due to these corrections. These tests may take a few minutes to run, depending on your machine :)
3 - Add "🚨🚨" to the PR title. We use it to flag output changes in our methods, so we don't forget to communicate about them in the next release 🤗

@gante (Contributor) commented Jun 24, 2023

Summary of the problem and corresponding fix (for the core maintainer and our future selves)

Problem

The generation loop in group_beam_search is correct: it builds num_beam_groups distinct groups of sequences. However, the beam_scorer.finalize() step did not take num_beam_groups into consideration, and the beam selection therein, when appending the last tokens, was free to write across groups. This should not happen at all, and it can entirely flush out the diversity in the different groups (when num_beam_groups >= num_beams/2), as the example in the PR header shows.
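
A toy illustration of the failure mode (the names and scores are made up for this sketch, not taken from the PR):

# Old behavior in miniature: one shared ranking across groups means the
# highest-scoring group's beams can occupy every final slot, discarding
# the other group's (diverse) candidate entirely.
num_beams = 2
candidates = [
    ("group A, beam 0", -1.0),
    ("group A, beam 1", -1.1),
    ("group B, beam 0", -2.5),  # diverse, but lower log score
]
shared_top = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
print([name for name, _ in shared_top])
# -> ['group A, beam 0', 'group A, beam 1'] -- group B is flushed out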

Fix

Two different paths were possible: a) add logic to finalize to handle groups correctly; b) treat each group as an independent set of hypotheses. From the paper, we can read "we divide the beam budget B into G groups and greedily optimize each group using beam search", so option b), kindly implemented by @hukuda222, is closer to the reference.
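
In miniature, option b) means the final selection ranks candidates within each group and can never cross group borders (a sketch with assumed toy data, not the actual transformers implementation):

# Each group keeps its own hypotheses and finalizes
# num_beams // num_beam_groups of them.
num_beams, num_beam_groups = 2, 2
group_size = num_beams // num_beam_groups  # beams finalized per group
candidates_per_group = {
    "A": [("group A, beam 0", -1.0), ("group A, beam 1", -1.1)],
    "B": [("group B, beam 0", -2.5)],
}
final = []
for group, candidates in candidates_per_group.items():
    ranked = sorted(candidates, key=lambda c: c[1], reverse=True)
    final += ranked[:group_size]
print([name for name, _ in final])
# -> ['group A, beam 0', 'group B, beam 0'] -- diversity preserved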

@hukuda222 changed the title from "Fix group beam search" to "🚨🚨 Fix group beam search" on Jun 25, 2023
@hukuda222 (Contributor Author):

@gante
Thanks for the review, CI now passes, and I confirmed that RUN_SLOW=1 py.test tests/generation/test_utils.py -vv also passes.

@gante requested a review from sgugger June 25, 2023 09:51
@gante (Contributor) commented Jun 25, 2023

@sgugger this comment summarizes the problem and the fix

@sgugger (Collaborator) left a comment

Thanks for this PR. I'm not clear on what the breaking changes are here (since there are 🚨 in the PR description). They should be stated very visibly in the first comment, along with the way to enable the previous behavior, if there is one.

@gante (Contributor) commented Jun 26, 2023

@sgugger the breaking changes here are in the generated outputs from group_beam_search, which are inevitable due to the bug fix. The method was underperforming (measured in log scores AND beam diversity, which is the point of the method) before these changes.

Since it is a bug fix, there is no need to ensure backward compatibility, correct?

@sgugger (Collaborator) left a comment

Not for a bug fix indeed. Thanks for clarifying!

@gante merged commit 43479ef into huggingface:main Jun 27, 2023
@gante mentioned this pull request Jun 29, 2023