
Conversation

ayushtiku5 commented Nov 18, 2020

What does this PR do?

Implementation of diverse beam search decoding as described in the paper: https://arxiv.org/pdf/1610.02424.pdf

Diversity function reference taken from: https://github.com/ashwinkalyan/dbs

Implementation details

Consider a T5 summarization task.
article="Justin Timberlake and Jessica Biel, welcome to parenthood. The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People. "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports. The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."

Generation using normal beam search can be done as:
model.generate( input_ids=input_ids, num_beams=2, num_return_sequences=2 )

This generates:
['the couple announced the pregnancy in January. it is the first baby for both.', 'the couple announced the pregnancy in January. it is the first baby for both of them ']

Generation using diverse beam search can be done as:
model.generate( input_ids=input_ids, num_beams=2, num_return_sequences=2, beam_groups=2, diversity_penalty=1.5 )

This generates:
['the couple announced the pregnancy in January. it is the first baby for both.', 'Justin Timberlake and Jessica Biel have welcomed their son, Silas Randall ']

This means that the 2 beams will be divided into 2 groups of 1 beam each, ensuring diversity between the groups. NOTE: If beam_groups=1, the result is the same as normal beam search, since all beams belong to the same group. A higher diversity_penalty enforces more diversity between the groups of beams. When generating with diverse beam search, we need to ensure that num_beams >= beam_groups and that num_beams is divisible by beam_groups.
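
A minimal sketch of that constraint (illustrative only, with made-up variable values, not code from this PR):

num_beams = 2
beam_groups = 2  # later exposed as num_beam_groups

assert num_beams >= beam_groups
assert num_beams % beam_groups == 0
group_size = num_beams // beam_groups  # here: 1 beam per group, searched independently at each step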

Who can review?

@patrickvonplaten, @TevenLeScao

ayushtiku5 commented Nov 18, 2020

@patrickvonplaten I am implementing diverse beam search. Please do suggest code design for this. 😃

@patrickvonplaten

> @patrickvonplaten I am implementing diverse beam search. Please do suggest code design for this.

Awesome that you work on this!

I think this looks like the right approach! However, I'd also recommend creating a new beam_scorer to be sure not to break backwards compatibility. We can see at a later stage if we can try to merge some code together with the current beam search code :-)

Also, can you add a link to the paper in this PR? That would be great :-)

ayushtiku5 marked this pull request as ready for review November 21, 2020 10:05
@ayushtiku5

@patrickvonplaten please review. I have made the required changes :)

@ayushtiku5

@patrickvonplaten just a gentle reminder to review the PR. Thanks!

@patrickvonplaten

> @patrickvonplaten just a gentle reminder to review the PR. Thanks!

Sorry, I'll review the PR this week! Also wondering how this PR relates to this one: #8840

@ayushtiku5

@patrickvonplaten I think #8840 ensures that the first token of every predicted sequence is different. This PR ensures diversity between groups of beams at every time step of sequence generation, so I think it is more generic. Also, we can control the extent of diversity using the diversity_penalty parameter.
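
To make the per-step penalty concrete, here is a rough sketch of the Hamming diversity term from the paper (hypothetical helper, not the PR's actual code): tokens already chosen by earlier groups at the current step are down-weighted by diversity_penalty.

import torch

def hamming_diversity_penalty(scores, prev_group_tokens, diversity_penalty, vocab_size):
    # scores: (group_size, vocab_size) log-probabilities for the current group
    # prev_group_tokens: 1-D tensor of token ids already selected by earlier groups at this step
    frequency = torch.bincount(prev_group_tokens, minlength=vocab_size).float()
    return scores - diversity_penalty * frequency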

patrickvonplaten linked an issue Dec 1, 2020 that may be closed by this pull request
patrickvonplaten commented Dec 7, 2020

> @patrickvonplaten Also I was thinking that currently I am subtracting the diversity penalty directly from the beam_scores. So, finally, when we are doing beam_scorer.finalize(), the final_beam_scores will also include the effect of diversity_penalty.
>
> I was thinking maybe we should penalise the beam_scores with the diversity penalty only when we are selecting the top 2*group_size beam candidates:
> next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True )
>
> But for choosing the final beams in the end, the scores shouldn't include the penalty due to diversity. What do you think?

Hey @ayushtiku5,

That's a good point! I do think though that we should leave the beam_scores as they are in the end as well. My main arguments are:

  1. It helps to have more diversity in the output. If we only use the diversity penalty for choosing the next beam_token, but do not add it to the beam_scores, the beam_scores will be very high for beams of similar tokens, which I think is what we want to prevent here. I think beam_scores should be penalized for every token in the corresponding beam_idx that is also present in another beam_idx of the same beam_group. It's also more consistent and logical IMO: we should update the beam_score with the probability that the current beam_id was selected.

  2. It would be very ugly to implement and I'd like to avoid it...

Is that fine for you?
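
For reference, a small self-contained sketch of the two options being discussed (dummy tensors and hypothetical names; only the first option is what the PR does):

import torch

torch.manual_seed(0)
group_size, vocab_size = 2, 10
raw_scores = torch.log_softmax(torch.randn(group_size, vocab_size), dim=-1)
frequency = torch.zeros(vocab_size)
frequency[3] = 1.0  # token 3 was already picked by an earlier group at this step
diversity_penalty = 1.5

penalized = raw_scores - diversity_penalty * frequency

# Option kept in the PR: the penalized scores are used both for ranking the
# candidates and as the running beam_scores, so finalize() sees the penalty too.
topk_scores, topk_tokens = torch.topk(penalized, 2 * group_size, dim=1, largest=True, sorted=True)

# Option discussed above but not adopted: use `penalized` only for this top-k
# selection and keep `raw_scores` as the beam_scores passed to finalize().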

@ayushtiku5

> @patrickvonplaten Also I was thinking that currently I am subtracting the diversity penalty directly from the beam_scores. So, finally, when we are doing beam_scorer.finalize(), the final_beam_scores will also include the effect of diversity_penalty.
> I was thinking maybe we should penalise the beam_scores with the diversity penalty only when we are selecting the top 2*group_size beam candidates:
> next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True )
> But for choosing the final beams in the end, the scores shouldn't include the penalty due to diversity. What do you think?
>
> Hey @ayushtiku5,
>
> That's a good point! I do think though that we should leave the beam_scores as they are in the end as well. My main arguments are:
>
>   1. It helps to have more diversity in the output. If we only use the diversity penalty for choosing the next beam_token, but do not add it to the beam_scores, the beam_scores will be very high for beams of similar tokens, which I think is what we want to prevent here. I think beam_scores should be penalized for every token in the corresponding beam_idx that is also present in another beam_idx of the same beam_group. It's also more consistent and logical IMO: we should update the beam_score with the probability that the current beam_id was selected.
>   2. It would be very ugly to implement and I'd like to avoid it...
>
> Is that fine for you?

@patrickvonplaten yeah sure, I am fine with this.

patrickvonplaten commented Dec 7, 2020

@ayushtiku5 - hope it's ok that I fiddled quite a bit with your PR. The functionality is kept 1:1 the same (I added an integration test at the very beginning to be sure of that), but the design is slightly different, with the main goal of keeping the method as general as possible.

IMO, the PR is now good to merge :-) Could you take a final look at whether the new names and design are ok for you?

Afterward, we can think about a nice code snippet / use case to advertise the big new feature of transformers :-)
Awesome job!

@patrickvonplaten

@ayushtiku5 do you think the following code snippet could be a nice use case of diverse beam search?

from transformers import pipeline
summarizer = pipeline("summarization", model="sshleifer/distilbart-xsum-12-6")

ARTICLE = """Part of the Broad Road was closed to traffic on Sunday at about 18:00 GMT.
The three adults and three children have been taken to Altnagelvin Hospital with non
life-threatening injuries. The Fire Service, Northern Ireland Ambulance Service
and police attended the crash. The Broad Road has since been reopened."""

# normal beam search
summarizer(ARTICLE, num_return_sequences=2)
# => [' Five people, including three children, have been taken to hospital following a two-vehicle crash in Londonderry.',
# ' Five people, including three children, have been taken to hospital after a two-vehicle crash in Londonderry.']

# diverse beam search
summarizer(ARTICLE, num_return_sequences=2, num_beam_groups=6, diversity_penalty=10.0)
# => ['Three men are in hospital after a car and a lorry crashed in Londonderry.',
# 'Six pedestrians were injured when a car and two vehicles crashed in County Antrim.']
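
Presumably the same output can also be produced by calling generate() directly with the new arguments from this PR. The sketch below reuses ARTICLE from the snippet above and uses the final parameter names (num_beam_groups, diversity_penalty); num_beams=6 is assumed so that it stays divisible by num_beam_groups.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-6")

input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    num_beams=6,
    num_beam_groups=6,
    num_return_sequences=2,
    diversity_penalty=10.0,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))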

patrickvonplaten left a comment

Good to merge for me. Would be nice if @LysandreJik @sgugger can check as well. Also cc @patil-suraj if you're interested.

ayushtiku5 commented Dec 7, 2020

> @ayushtiku5 - hope it's ok that I fiddled quite a bit with your PR. The functionality is kept 1:1 the same (I added an integration test at the very beginning to be sure of that), but the design is slightly different, with the main goal of keeping the method as general as possible.
>
> IMO, the PR is now good to merge :-) Could you take a final look at whether the new names and design are ok for you?
>
> Afterward, we can think about a nice code snippet / use case to advertise the big new feature of transformers :-)
> Awesome job!

Hey @patrickvonplaten,

Just one thing. In the BeamScorer's finalize() method, we are directly selecting top num_beams beams from the final_beam_scores. This assumes that the beam scores in final_beam_scores will be sorted in decreasing order for a particular batch_idx. However, this will not be the case for our diverse beam search. final_beam_scores will be sorted for the beams inside a particular group, but not necessarily for all the beams for a particular batch_idx. So, I think we will have to sort the final_beam_scores for every batch_idx. I did this previously here

The rest looks good to me. Thanks for refactoring!

[UPDATE]: added this in this commit
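
For context, a self-contained sketch of the kind of per-batch sort being proposed here (dummy tensors and hypothetical shapes; as discussed in the follow-up comment, it was later reverted as unnecessary):

import torch

batch_size, num_beams = 2, 4
final_beam_scores = torch.randn(batch_size * num_beams)

# Sort the beams within each batch index so the top num_beams per batch come first.
per_batch = final_beam_scores.view(batch_size, num_beams)
sorted_scores, sorted_indices = torch.sort(per_batch, dim=1, descending=True)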

patrickvonplaten commented Dec 7, 2020

Hey @ayushtiku5,

sorry I forgot to mention why I deleted those lines. IMO we don't need to add this functionality because it doesn't matter whether the scores are sorted or not. In this line:

def add(self, hyp: torch.LongTensor, sum_logprobs: float):
you can see that the add(...) method automatically keeps the best scores and throws out the worst ones. Since the loop goes through all scores anyway, it does not matter IMO whether they are sorted or not.

What do you think? IMO, we can revert the last commit.
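
A simplified sketch of the point being made (not the actual transformers BeamHypotheses class): an add()-style method that keeps only the best hypotheses behaves the same regardless of the order in which scores arrive.

class TopHypotheses:
    def __init__(self, num_beams):
        self.num_beams = num_beams
        self.beams = []  # list of (score, hypothesis) pairs

    def add(self, hyp, sum_logprobs):
        self.beams.append((sum_logprobs, hyp))
        if len(self.beams) > self.num_beams:
            # drop the single worst hypothesis; insertion order never matters
            worst = min(range(len(self.beams)), key=lambda i: self.beams[i][0])
            del self.beams[worst]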

This reverts commit c99eb5a.
@ayushtiku5

> Hey @ayushtiku5,
>
> sorry I forgot to mention why I deleted those lines. IMO we don't need to add this functionality because it doesn't matter whether the scores are sorted or not. In this line:
>
> def add(self, hyp: torch.LongTensor, sum_logprobs: float):
>
> you can see that the add(...) method automatically keeps the best scores and throws out the worst ones. Since the loop goes through all scores anyway, it does not matter IMO whether they are sorted or not.
> What do you think? IMO, we can revert the last commit.

Yeah sorry! I completely missed it. Reverted the commit.

@patrickvonplaten

> Hey @ayushtiku5,
> sorry I forgot to mention why I deleted those lines. IMO we don't need to add this functionality because it doesn't matter whether the scores are sorted or not. In this line:
>
> def add(self, hyp: torch.LongTensor, sum_logprobs: float):
>
> you can see that the add(...) method automatically keeps the best scores and throws out the worst ones. Since the loop goes through all scores anyway, it does not matter IMO whether they are sorted or not.
> What do you think? IMO, we can revert the last commit.
>
> Yeah sorry! I completely missed it. Reverted the commit.

No worries :-) The comment wasn't the best either - I updated it. Think it's a bit clearer now.

sgugger left a comment

Looks very clean to me! Just have a few nits here and there.
Thanks a lot for your PR!

@patrickvonplaten

@ayushtiku5 - super sorry, we messed up the previous branch yesterday. I opened a new PR with the same authorship -> so it should be good to merge :-)


Successfully merging this pull request may close these issues.

Diverse Beam Search decoding
