From 14d5b6ca6e527eac2cdb9e9400d4c00f6d7add01 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Tue, 17 Nov 2020 15:50:32 +0530 Subject: [PATCH 01/31] diverse beam search --- src/transformers/generation_beam_search.py | 40 ++++--- src/transformers/generation_utils.py | 116 ++++++++++++++------- 2 files changed, 106 insertions(+), 50 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 135227895d89..47a42f7e18d7 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -95,6 +95,7 @@ class BeamScorer(ABC): def process( self, input_ids: torch.LongTensor, + group_size: int, next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, @@ -141,6 +142,13 @@ class BeamSearchScorer(BeamScorer): num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): The number of beam hypotheses that shall be returned upon calling :meth:`~transformer.BeamSearchScorer.finalize`. + beam_groups (:obj:`int`, `optional`, defaults to 1): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups + of beams. See `this paper + `__ for more details. + diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. """ def __init__( @@ -152,6 +160,8 @@ def __init__( length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, num_beam_hyps_to_keep: Optional[int] = 1, + beam_groups: Optional[int] = 1, + diversity_penalty: Optional[float] = 0.0 ): self.max_length = max_length self.num_beams = num_beams @@ -159,6 +169,8 @@ def __init__( self.length_penalty = length_penalty self.do_early_stopping = do_early_stopping self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.beam_groups = beam_groups # when beam_groups=1 it is same as normal beam search + self.diversity_penalty = diversity_penalty self._is_init = False self._beam_hyps = [ @@ -184,6 +196,7 @@ def is_done(self) -> bool: def process( self, input_ids: torch.LongTensor, + group_size: int, next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, @@ -192,12 +205,12 @@ def process( ) -> Tuple[torch.Tensor]: cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.num_beams) + assert batch_size == (input_ids.shape[0] // group_size) device = input_ids.device - next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) + next_beam_scores = torch.zeros((batch_size, group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, group_size), dtype=next_indices.dtype, device=device) for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -218,11 +231,11 @@ def process( for beam_token_rank, (next_token, next_score, next_index) in enumerate( zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) ): - batch_beam_idx = batch_idx * self.num_beams + next_index + batch_beam_idx = batch_idx * group_size + next_index # add to generated hypotheses if end of sentence if (eos_token_id is not None) and (next_token.item() == eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + is_beam_token_worse_than_top_num_beams = beam_token_rank >= group_size if is_beam_token_worse_than_top_num_beams: continue beam_hyp.add( @@ -237,12 +250,12 @@ def process( beam_idx += 1 # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.num_beams: + if beam_idx == group_size: break - if beam_idx < self.num_beams: + if beam_idx < group_size: raise ValueError( - f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + f"At most {group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." ) # Check if we are done so that we can save a pad step if all(done) @@ -268,16 +281,19 @@ def finalize( eos_token_id: Optional[int] = None, ) -> torch.LongTensor: batch_size = len(self._beam_hyps) + final_beam_scores = final_beam_scores.view((batch_size, self.num_beams)) # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: continue + batch_beam_scores = final_beam_scores[batch_idx, :] + _, beam_ids = torch.sort(batch_beam_scores, descending=True) # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() + for beam_id in beam_ids: + batch_beam_idx = batch_idx * self.num_beams + beam_id.item() + final_score = batch_beam_scores[beam_id.item()].item() final_tokens = input_ids[batch_beam_idx] beam_hyp.add(final_tokens, final_score) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 206658da98ad..d754ab025c4f 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -300,6 +300,8 @@ def generate( num_return_sequences: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, + beam_groups: Optional[int] = 1, + diversity_penalty: Optional[float] = 0.0, **model_kwargs ) -> torch.LongTensor: r""" @@ -366,6 +368,13 @@ def generate( use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + beam_groups (:obj:`int`, `optional`, defaults to 1): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups + of beams. See `this paper + `__ for more details. + diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific @@ -537,6 +546,9 @@ def generate( if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + if num_beams % beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `beam_groups` for diverse beam search.") + beam_scorer = BeamSearchScorer( batch_size=batch_size, max_length=max_length, @@ -545,6 +557,8 @@ def generate( length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, + beam_groups=beam_groups, + diversity_penalty=diversity_penalty ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -940,6 +954,9 @@ def beam_search( batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams + beam_groups = beam_scorer.beam_groups + diversity_penalty = beam_scorer.diversity_penalty + num_sub_beams = num_beams // beam_groups batch_beam_size, cur_len = input_ids.shape @@ -952,53 +969,76 @@ def beam_search( beam_scores = beam_scores.view((batch_size * num_beams,)) while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + recent_tokens = torch.zeros(batch_size * num_beams) + for beam_group_idx in range(beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices to predict the next token on + batch_generation_indices = [] + for batch_idx in range(batch_size): + batch_generation_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]) + # predict next token for beams of current group for all sentences in the batch + group_input_ids = input_ids[batch_generation_indices, :] + + model_inputs = self.prepare_inputs_for_generation(group_input_ids, **model_kwargs) + + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # adjust tokens for Bart, *e.g.* + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) - outputs = self(**model_inputs, return_dict=True) - next_token_logits = outputs.logits[:, -1, :] + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + for batch_idx in range(batch_size): + previous_group_input_ids = input_ids[batch_idx * num_beams: batch_idx * num_beams + group_start_idx, -1] + token_frequency = torch.zeros(vocab_size) + for token in previous_group_input_ids: + token_frequency[token.item()] += 1 + next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] = next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] - diversity_penalty * token_frequency - # adjust tokens for Bart, *e.g.* - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=max_length - ) + next_token_scores = logits_processor(group_input_ids, next_token_scores) + next_token_scores = next_token_scores + beam_scores[batch_generation_indices].expand_as(next_token_scores) + # reshape for beam search - next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) - next_token_scores = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True - ) + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + group_input_ids, + group_size, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores[batch_generation_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] - next_indices = next_tokens // vocab_size - next_tokens = next_tokens % vocab_size + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + recent_tokens[batch_generation_indices] = group_input_ids[:, -1] - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + input_ids = torch.cat([input_ids, recent_tokens.unsqueeze(-1)], dim=-1) cur_len = cur_len + 1 - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) - if beam_scorer.is_done: break From de6280f3bd4abc485e0f7aac6fb951df6691a0e9 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Wed, 18 Nov 2020 14:13:47 +0530 Subject: [PATCH 02/31] bug fixes --- src/transformers/generation_utils.py | 46 ++++++++++++++++------------ 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index d754ab025c4f..454ec5cda270 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -969,23 +969,25 @@ def beam_search( beam_scores = beam_scores.view((batch_size * num_beams,)) while cur_len < max_length: - recent_tokens = torch.zeros(batch_size * num_beams) + recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self(**model_inputs, return_dict=True) + for beam_group_idx in range(beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx - # indices to predict the next token on - batch_generation_indices = [] + # indices of beams of current group among all sentences in batch + batch_group_indices = [] for batch_idx in range(batch_size): - batch_generation_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]) - # predict next token for beams of current group for all sentences in the batch - group_input_ids = input_ids[batch_generation_indices, :] - - model_inputs = self.prepare_inputs_for_generation(group_input_ids, **model_kwargs) + batch_group_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]) + group_input_ids = input_ids[batch_group_indices] - outputs = self(**model_inputs, return_dict=True) - next_token_logits = outputs.logits[:, -1, :] + # select outputs of beams of current group only + next_token_logits = outputs.logits[batch_group_indices, -1, :] # adjust tokens for Bart, *e.g.* next_token_logits = self.adjust_logits_during_generation( @@ -994,15 +996,19 @@ def beam_search( next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) vocab_size = next_token_scores.shape[-1] + + # hamming diversity: penalise using same token in current group which was used in previous groups at + # the same time step for batch_idx in range(batch_size): - previous_group_input_ids = input_ids[batch_idx * num_beams: batch_idx * num_beams + group_start_idx, -1] + # predicted tokens of last time step of previous groups + previous_group_tokens = recent_tokens[batch_idx * num_beams: batch_idx * num_beams + group_start_idx] token_frequency = torch.zeros(vocab_size) - for token in previous_group_input_ids: + for token in previous_group_tokens: token_frequency[token.item()] += 1 next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] = next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] - diversity_penalty * token_frequency next_token_scores = logits_processor(group_input_ids, next_token_scores) - next_token_scores = next_token_scores + beam_scores[batch_generation_indices].expand_as(next_token_scores) + next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(next_token_scores) # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) @@ -1024,18 +1030,18 @@ def beam_search( pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) - beam_scores[batch_generation_indices] = beam_outputs["next_beam_scores"] + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - recent_tokens[batch_generation_indices] = group_input_ids[:, -1] + recent_tokens[batch_group_indices] = group_input_ids[:, -1] - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) input_ids = torch.cat([input_ids, recent_tokens.unsqueeze(-1)], dim=-1) cur_len = cur_len + 1 From 3ab6551a41075da00e010958cb7caffe84fe1f3f Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Wed, 18 Nov 2020 17:10:07 +0530 Subject: [PATCH 03/31] bug fixes --- src/transformers/generation_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 454ec5cda270..9472d8515628 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -969,6 +969,7 @@ def beam_search( beam_scores = beam_scores.view((batch_size * num_beams,)) while cur_len < max_length: + # predicted tokens in cur_len step recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype) # do one decoder step on all beams of all sentences in batch @@ -1034,6 +1035,7 @@ def beam_search( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] + input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) recent_tokens[batch_group_indices] = group_input_ids[:, -1] From da6654d0a3cc8a9708254ff3f60720a8028a2464 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Wed, 18 Nov 2020 23:11:23 +0530 Subject: [PATCH 04/31] bug fix --- src/transformers/generation_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 9472d8515628..17d7b6cfd81a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -964,8 +964,8 @@ def beam_search( num_beams * batch_size == batch_beam_size ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=input_ids.device) + beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) while cur_len < max_length: From 9e7d9766a5be74a2b0158c3a5511df7218dfed3b Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Fri, 20 Nov 2020 10:09:32 +0530 Subject: [PATCH 05/31] separate out diverse_beam_search function --- src/transformers/generation_beam_search.py | 9 ++------- src/transformers/generation_utils.py | 6 ++---- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 47a42f7e18d7..9f71ffa0ead4 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -146,9 +146,6 @@ class BeamSearchScorer(BeamScorer): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of beams. See `this paper `__ for more details. - diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. """ def __init__( @@ -161,7 +158,6 @@ def __init__( do_early_stopping: Optional[bool] = False, num_beam_hyps_to_keep: Optional[int] = 1, beam_groups: Optional[int] = 1, - diversity_penalty: Optional[float] = 0.0 ): self.max_length = max_length self.num_beams = num_beams @@ -170,7 +166,6 @@ def __init__( self.do_early_stopping = do_early_stopping self.num_beam_hyps_to_keep = num_beam_hyps_to_keep self.beam_groups = beam_groups # when beam_groups=1 it is same as normal beam search - self.diversity_penalty = diversity_penalty self._is_init = False self._beam_hyps = [ @@ -196,13 +191,13 @@ def is_done(self) -> bool: def process( self, input_ids: torch.LongTensor, - group_size: int, next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None ) -> Tuple[torch.Tensor]: + group_size = self.num_beams // self.beam_groups cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) assert batch_size == (input_ids.shape[0] // group_size) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 17d7b6cfd81a..90614f8c27ab 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -557,8 +557,7 @@ def generate( length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, - beam_groups=beam_groups, - diversity_penalty=diversity_penalty + beam_groups=beam_groups ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -858,7 +857,7 @@ def sample( return input_ids - def beam_search( + def diverse_beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, @@ -1024,7 +1023,6 @@ def beam_search( # stateless beam_outputs = beam_scorer.process( group_input_ids, - group_size, next_token_scores, next_tokens, next_indices, From 69a91b401e7154d82a1f11573b4265b67e4cefb6 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Fri, 20 Nov 2020 11:03:05 +0530 Subject: [PATCH 06/31] separate out diverse_beam_search function --- src/transformers/generation_utils.py | 352 +++++++++++++++++++++------ 1 file changed, 276 insertions(+), 76 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 90614f8c27ab..ea17cd546da8 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -479,10 +479,13 @@ def generate( raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") # determine generation mode - is_greedy_gen_mode = (num_beams == 1) and do_sample is False - is_sample_gen_mode = (num_beams == 1) and do_sample is True - is_beam_gen_mode = (num_beams > 1) and do_sample is False - is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True + is_greedy_gen_mode = (num_beams == 1) and (beam_groups == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and (beam_groups == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is False + is_beam_sample_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is True + is_diverse_beam_gen_mode = (num_beams > 1) and (beam_groups > 1) + if beam_groups > num_beams: + raise ValueError(f"beam_groups has to be smaller or equal to num_beams") # set model_kwargs model_kwargs["use_cache"] = use_cache @@ -546,9 +549,6 @@ def generate( if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") - if num_beams % beam_groups != 0: - raise ValueError("`num_beams` should be divisible by `beam_groups` for diverse beam search.") - beam_scorer = BeamSearchScorer( batch_size=batch_size, max_length=max_length, @@ -556,8 +556,7 @@ def generate( device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - beam_groups=beam_groups + num_beam_hyps_to_keep=num_return_sequences ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -609,6 +608,43 @@ def generate( **model_kwargs, ) + elif is_diverse_beam_gen_mode: + batch_size = input_ids.shape[0] + + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if num_beams % beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `beam_groups` for diverse beam search.") + + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + beam_groups=beam_groups + ) + # interleave with `num_beams` + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + return self.diverse_beam_search( + input_ids, + diversity_penalty, + beam_scorer, + logits_processor=logits_processor, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, + ) + def greedy_search( self, input_ids: torch.LongTensor, @@ -857,7 +893,7 @@ def sample( return input_ids - def diverse_beam_search( + def beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, @@ -953,9 +989,6 @@ def diverse_beam_search( batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams - beam_groups = beam_scorer.beam_groups - diversity_penalty = beam_scorer.diversity_penalty - num_sub_beams = num_beams // beam_groups batch_beam_size, cur_len = input_ids.shape @@ -963,79 +996,51 @@ def diverse_beam_search( num_beams * batch_size == batch_beam_size ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=input_ids.device) - beam_scores[:, ::num_sub_beams] = 0 + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) while cur_len < max_length: - # predicted tokens in cur_len step - recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype) - - # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self(**model_inputs, return_dict=True) - for beam_group_idx in range(beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - for batch_idx in range(batch_size): - batch_group_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of current group only - next_token_logits = outputs.logits[batch_group_indices, -1, :] - - # adjust tokens for Bart, *e.g.* - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=max_length - ) - - next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) - vocab_size = next_token_scores.shape[-1] + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] - # hamming diversity: penalise using same token in current group which was used in previous groups at - # the same time step - for batch_idx in range(batch_size): - # predicted tokens of last time step of previous groups - previous_group_tokens = recent_tokens[batch_idx * num_beams: batch_idx * num_beams + group_start_idx] - token_frequency = torch.zeros(vocab_size) - for token in previous_group_tokens: - token_frequency[token.item()] += 1 - next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] = next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] - diversity_penalty * token_frequency + # adjust tokens for Bart, *e.g.* + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) - next_token_scores = logits_processor(group_input_ids, next_token_scores) - next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(next_token_scores) - # reshape for beam search + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) - next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + next_token_scores = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True - ) + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) - next_indices = next_tokens // vocab_size - next_tokens = next_tokens % vocab_size + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size - # stateless - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - recent_tokens[batch_group_indices] = group_input_ids[:, -1] + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder @@ -1043,8 +1048,6 @@ def diverse_beam_search( if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) - input_ids = torch.cat([input_ids, recent_tokens.unsqueeze(-1)], dim=-1) - cur_len = cur_len + 1 if beam_scorer.is_done: break @@ -1231,6 +1234,203 @@ def beam_sample( return decoded + def diverse_beam_search( + self, + input_ids: torch.LongTensor, + diversity_penalty: float, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + r""" + Generates sequences for models with a language modeling head using beam search decoding. + + Parameters: + + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty + :obj:`torch.LongTensor` of shape :obj:`(1,)`. + beam_scorer (:obj:`BeamScorer`): + An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are + constructed, stored and sorted during generation. For more information, the documentation of + :class:`~transformers.BeamScorer` should be read. + logits_processor (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling + head applied at each generation step. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + model_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If + model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + + Examples:: + + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList([ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ]) + + >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + """ + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + beam_groups = beam_scorer.beam_groups + num_sub_beams = num_beams // beam_groups + + batch_beam_size, cur_len = input_ids.shape + + assert ( + num_beams * batch_size == batch_beam_size + ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=input_ids.device) + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while cur_len < max_length: + # predicted tokens in cur_len step + recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self(**model_inputs, return_dict=True) + + for beam_group_idx in range(beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + for batch_idx in range(batch_size): + batch_group_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits = outputs.logits[batch_group_indices, -1, :] + + # adjust tokens for Bart, *e.g.* + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) + + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + # hamming diversity: penalise using same token in current group which was used in previous groups at + # the same time step + for batch_idx in range(batch_size): + # predicted tokens of last time step of previous groups + previous_group_tokens = recent_tokens[batch_idx * num_beams: batch_idx * num_beams + group_start_idx] + token_frequency = torch.zeros(vocab_size) + for token in previous_group_tokens: + token_frequency[token.item()] += 1 + next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] = next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] - diversity_penalty * token_frequency + + next_token_scores = logits_processor(group_input_ids, next_token_scores) + next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(next_token_scores) + # reshape for beam search + + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + recent_tokens[batch_group_indices] = group_input_ids[:, -1] + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + + input_ids = torch.cat([input_ids, recent_tokens.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 + if beam_scorer.is_done: + break + + decoded = beam_scorer.finalize( + input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + ) + + return decoded + def top_k_top_p_filtering( logits: torch.FloatTensor, From a1b57d2d4bc1546e1c0fde88dc4d8e6b0287df57 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Fri, 20 Nov 2020 11:04:45 +0530 Subject: [PATCH 07/31] bug fix --- src/transformers/generation_beam_search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 9f71ffa0ead4..3820f4cfa86f 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -95,7 +95,6 @@ class BeamScorer(ABC): def process( self, input_ids: torch.LongTensor, - group_size: int, next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, From 8ffe8fdc500ecbe9f5cdbc6b97e3e04b1db6e46b Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Fri, 20 Nov 2020 12:04:39 +0530 Subject: [PATCH 08/31] improve code quality --- src/transformers/generation_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index ea17cd546da8..82980faa08b9 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1253,6 +1253,9 @@ def diverse_beam_search( input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. + diversity_penalty (:obj:`float`): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. beam_scorer (:obj:`BeamScorer`): An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of @@ -1341,6 +1344,8 @@ def diverse_beam_search( ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=input_ids.device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -1379,10 +1384,8 @@ def diverse_beam_search( for batch_idx in range(batch_size): # predicted tokens of last time step of previous groups previous_group_tokens = recent_tokens[batch_idx * num_beams: batch_idx * num_beams + group_start_idx] - token_frequency = torch.zeros(vocab_size) - for token in previous_group_tokens: - token_frequency[token.item()] += 1 - next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] = next_token_scores[batch_idx * group_size : (batch_idx + 1)*group_size] - diversity_penalty * token_frequency + token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size) + next_token_scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= diversity_penalty * token_frequency next_token_scores = logits_processor(group_input_ids, next_token_scores) next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(next_token_scores) From 191a59daf65cfb4e9f87de77406bd0622bbb1aa1 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Fri, 20 Nov 2020 13:50:53 +0530 Subject: [PATCH 09/31] bug fix --- src/transformers/generation_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 82980faa08b9..d5a1b8e64a27 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1336,6 +1336,7 @@ def diverse_beam_search( num_beams = beam_scorer.num_beams beam_groups = beam_scorer.beam_groups num_sub_beams = num_beams // beam_groups + device = input_ids.device batch_beam_size, cur_len = input_ids.shape @@ -1343,7 +1344,7 @@ def diverse_beam_search( num_beams * batch_size == batch_beam_size ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=input_ids.device) + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in # the same group don't produce same tokens everytime. beam_scores[:, ::num_sub_beams] = 0 @@ -1351,7 +1352,7 @@ def diverse_beam_search( while cur_len < max_length: # predicted tokens in cur_len step - recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype) + recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -1384,7 +1385,7 @@ def diverse_beam_search( for batch_idx in range(batch_size): # predicted tokens of last time step of previous groups previous_group_tokens = recent_tokens[batch_idx * num_beams: batch_idx * num_beams + group_start_idx] - token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size) + token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(device) next_token_scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= diversity_penalty * token_frequency next_token_scores = logits_processor(group_input_ids, next_token_scores) From ea945b2c1cc890fad357f3e367f4b034a56f61d1 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Fri, 20 Nov 2020 14:56:29 +0530 Subject: [PATCH 10/31] bug fix --- src/transformers/generation_utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index d5a1b8e64a27..45d2b9190cf4 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1297,8 +1297,8 @@ def diverse_beam_search( >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - >>> # lets run beam search using 3 beams - >>> num_beams = 3 + >>> # lets run diverse beam search using 6 beams + >>> num_beams = 6 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id @@ -1314,6 +1314,7 @@ def diverse_beam_search( ... max_length=model.config.max_length, ... num_beams=num_beams, ... device=model.device, + ... beam_groups=3 ... ) >>> # instantiate logits processors @@ -1321,7 +1322,7 @@ def diverse_beam_search( ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ]) - >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + >>> outputs = model.diverse_beam_search(input_ids, 5.5, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) """ @@ -1354,6 +1355,9 @@ def diverse_beam_search( # predicted tokens in cur_len step recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self(**model_inputs, return_dict=True) @@ -1418,11 +1422,15 @@ def diverse_beam_search( group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) recent_tokens[batch_group_indices] = group_input_ids[:, -1] + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) input_ids = torch.cat([input_ids, recent_tokens.unsqueeze(-1)], dim=-1) cur_len = cur_len + 1 From 4683c043dc0dc76d29b099e2ef8f87ec6a3abc4b Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 10:47:09 +0530 Subject: [PATCH 11/31] separate out diverse beam search scorer --- src/transformers/__init__.py | 2 +- src/transformers/generation_beam_search.py | 222 +++++++++++++++++++-- src/transformers/generation_utils.py | 12 +- 3 files changed, 215 insertions(+), 21 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a92fb488125b..351c4400789f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -294,7 +294,7 @@ TextDataset, TextDatasetForNextSentencePrediction, ) - from .generation_beam_search import BeamScorer, BeamSearchScorer + from .generation_beam_search import BeamScorer, BeamSearchScorer, DiverseBeamSearchScorer from .generation_logits_process import ( LogitsProcessor, LogitsProcessorList, diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 3820f4cfa86f..96a8857f0492 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -141,10 +141,204 @@ class BeamSearchScorer(BeamScorer): num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): The number of beam hypotheses that shall be returned upon calling :meth:`~transformer.BeamSearchScorer.finalize`. - beam_groups (:obj:`int`, `optional`, defaults to 1): + """ + + def __init__( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + max_length=self.max_length, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + assert batch_size == (input_ids.shape[0] // self.num_beams) + + device = input_ids.device + next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + assert ( + len(beam_hyp) >= self.num_beams + ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.num_beams + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.num_beams: + break + + if beam_idx < self.num_beams: + raise ValueError( + f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> torch.LongTensor: + batch_size = len(self._beam_hyps) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add(final_tokens, final_score) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp = sorted_hyps.pop()[1] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + best.append(best_hyp) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < self.max_length: + decoded[i, sent_lengths[i]] = eos_token_id + return decoded + + +class DiverseBeamSearchScorer(BeamScorer): + r""" + :class:`transformers.BeamScorer` implementing diverse beam search decoding. + + Adapted in part from `Facebook's XLM beam search code + `__. + + Args: + batch_size (:obj:`int`): + Batch Size of :obj:`input_ids` for which diverse beam search decoding is run in parallel. + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + num_beams (:obj:`int`): + Number of beams for beam search. + beam_groups (:obj:`int`): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of beams. See `this paper `__ for more details. + device (:obj:`torch.device`): + Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of + :obj:`BeamSearchScorer` will be allocated. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + :meth:`~transformer.BeamSearchScorer.finalize`. """ def __init__( @@ -152,11 +346,11 @@ def __init__( batch_size: int, max_length: int, num_beams: int, + beam_groups: int, device: torch.device, length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - beam_groups: Optional[int] = 1, + num_beam_hyps_to_keep: Optional[int] = 1 ): self.max_length = max_length self.num_beams = num_beams @@ -164,7 +358,8 @@ def __init__( self.length_penalty = length_penalty self.do_early_stopping = do_early_stopping self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - self.beam_groups = beam_groups # when beam_groups=1 it is same as normal beam search + self.beam_groups = beam_groups + self.group_size = self.num_beams // self.beam_groups self._is_init = False self._beam_hyps = [ @@ -196,15 +391,14 @@ def process( pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None ) -> Tuple[torch.Tensor]: - group_size = self.num_beams // self.beam_groups cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // group_size) + assert batch_size == (input_ids.shape[0] // self.group_size) device = input_ids.device - next_beam_scores = torch.zeros((batch_size, group_size), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, group_size), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, group_size), dtype=next_indices.dtype, device=device) + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -225,11 +419,11 @@ def process( for beam_token_rank, (next_token, next_score, next_index) in enumerate( zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) ): - batch_beam_idx = batch_idx * group_size + next_index + batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence if (eos_token_id is not None) and (next_token.item() == eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= group_size + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size if is_beam_token_worse_than_top_num_beams: continue beam_hyp.add( @@ -244,12 +438,12 @@ def process( beam_idx += 1 # once the beam for next step is full, don't add more tokens to it. - if beam_idx == group_size: + if beam_idx == self.group_size: break - if beam_idx < group_size: + if beam_idx < self.group_size: raise ValueError( - f"At most {group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." ) # Check if we are done so that we can save a pad step if all(done) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index c8e50b62e53f..747fda2c0f75 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from .file_utils import ModelOutput -from .generation_beam_search import BeamScorer, BeamSearchScorer +from .generation_beam_search import BeamScorer, BeamSearchScorer, DiverseBeamSearchScorer from .generation_logits_process import ( LogitsProcessorList, MinLengthLogitsProcessor, @@ -644,15 +644,15 @@ def generate( if num_beams % beam_groups != 0: raise ValueError("`num_beams` should be divisible by `beam_groups` for diverse beam search.") - beam_scorer = BeamSearchScorer( + diverse_beam_scorer = DiverseBeamSearchScorer( batch_size=batch_size, max_length=max_length, num_beams=num_beams, + beam_groups=beam_groups, device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - beam_groups=beam_groups + num_beam_hyps_to_keep=num_return_sequences ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -661,7 +661,7 @@ def generate( return self.diverse_beam_search( input_ids, diversity_penalty, - beam_scorer, + diverse_beam_scorer, logits_processor=logits_processor, max_length=max_length, pad_token_id=pad_token_id, @@ -1333,7 +1333,7 @@ def diverse_beam_search( ... } >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( + >>> beam_scorer = DiverseBeamSearchScorer( ... batch_size=1, ... max_length=model.config.max_length, ... num_beams=num_beams, From 77a861ad4c704abc4bd82e98c0f22067edb6f890 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 11:06:23 +0530 Subject: [PATCH 12/31] code format --- src/transformers/generation_beam_search.py | 4 ++-- src/transformers/generation_utils.py | 24 +++++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 96a8857f0492..41fec7d304b6 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -350,7 +350,7 @@ def __init__( device: torch.device, length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1 + num_beam_hyps_to_keep: Optional[int] = 1, ): self.max_length = max_length self.num_beams = num_beams @@ -389,7 +389,7 @@ def process( next_tokens: torch.LongTensor, next_indices: torch.LongTensor, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None + eos_token_id: Optional[int] = None, ) -> Tuple[torch.Tensor]: cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 747fda2c0f75..45dd7068b43f 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -580,7 +580,7 @@ def generate( device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences + num_beam_hyps_to_keep=num_return_sequences, ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -652,7 +652,7 @@ def generate( device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences + num_beam_hyps_to_keep=num_return_sequences, ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -1394,7 +1394,9 @@ def diverse_beam_search( # indices of beams of current group among all sentences in batch batch_group_indices = [] for batch_idx in range(batch_size): - batch_group_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]) + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) group_input_ids = input_ids[batch_group_indices] # select outputs of beams of current group only @@ -1412,12 +1414,18 @@ def diverse_beam_search( # the same time step for batch_idx in range(batch_size): # predicted tokens of last time step of previous groups - previous_group_tokens = recent_tokens[batch_idx * num_beams: batch_idx * num_beams + group_start_idx] + previous_group_tokens = recent_tokens[ + batch_idx * num_beams : batch_idx * num_beams + group_start_idx + ] token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(device) - next_token_scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= diversity_penalty * token_frequency + next_token_scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= ( + diversity_penalty * token_frequency + ) next_token_scores = logits_processor(group_input_ids, next_token_scores) - next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(next_token_scores) + next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as( + next_token_scores + ) # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) @@ -1448,7 +1456,9 @@ def diverse_beam_search( # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size) + reordering_indices[batch_group_indices] = ( + num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size) + ) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder From 26eb8a8c212fdfaec7f599197796fa4e3b399d84 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 11:27:35 +0530 Subject: [PATCH 13/31] code format --- src/transformers/generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 45dd7068b43f..ae41c4c2c514 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -507,7 +507,7 @@ def generate( is_beam_sample_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is True is_diverse_beam_gen_mode = (num_beams > 1) and (beam_groups > 1) if beam_groups > num_beams: - raise ValueError(f"beam_groups has to be smaller or equal to num_beams") + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`") # set model_kwargs model_kwargs["use_cache"] = use_cache From c35492b7814e2e3ce66b2b9989fd3b5e362db75a Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 11:42:00 +0530 Subject: [PATCH 14/31] code format --- src/transformers/generation_beam_search.py | 5 ++--- src/transformers/generation_utils.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 41fec7d304b6..8287c7525e4b 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -324,9 +324,8 @@ class DiverseBeamSearchScorer(BeamScorer): num_beams (:obj:`int`): Number of beams for beam search. beam_groups (:obj:`int`): - Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups - of beams. See `this paper - `__ for more details. + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. device (:obj:`torch.device`): Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of :obj:`BeamSearchScorer` will be allocated. diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index ae41c4c2c514..d608d87fa567 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -384,9 +384,8 @@ def generate( Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. beam_groups (:obj:`int`, `optional`, defaults to 1): - Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups - of beams. See `this paper - `__ for more details. + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. From f63017ab9e5a42c48c21f93d9c2d73d68dae5e46 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 11:44:34 +0530 Subject: [PATCH 15/31] code format --- src/transformers/utils/dummy_pt_objects.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b0e81bd8cbc1..2d9efa7c4985 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -118,6 +118,11 @@ def __init__(self, *args, **kwargs): requires_pytorch(self) +class DiverseBeamSearchScorer: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class LogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) From 217a1fa80b60900f90f1ae7c04775aa23f01269c Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 15:07:18 +0530 Subject: [PATCH 16/31] add test --- tests/test_generation_utils.py | 112 ++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 433dad34e680..d15f2a8544fe 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -24,7 +24,7 @@ import torch from transformers import top_k_top_p_filtering - from transformers.generation_beam_search import BeamSearchScorer + from transformers.generation_beam_search import BeamSearchScorer, DiverseBeamSearchScorer from transformers.generation_logits_process import ( LogitsProcessorList, MinLengthLogitsProcessor, @@ -115,6 +115,28 @@ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): ) return beam_kwargs, beam_scorer + @staticmethod + def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + beam_kwargs = { + "early_stopping": False, + "length_penalty": 2.0, + "num_beams": 2, + "num_return_sequences": num_return_sequences, + "beam_groups": 2, # one beam per group + "diversity_penalty": 1.0, + } + beam_scorer = DiverseBeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=beam_kwargs["num_beams"], + beam_groups=beam_kwargs["beam_groups"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + ) + return beam_kwargs, beam_scorer + @staticmethod def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1): encoder = model.get_encoder() @@ -408,6 +430,94 @@ def test_generate_without_input_ids(self): self.assertIsNotNone(output_ids_generate) + def test_diverse_beam_search_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id + ) + + model = model_class(config).to(torch_device) + model.eval() + + # check `generate()` and `diverse_beam_search()` are equal + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + **beam_kwargs, + **logits_process_kwargs, + ) + + # diverse_beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_ids_diverse_beam_search = model.diverse_beam_search( + input_ids_clone, + beam_kwargs["diversity_penalty"], + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_diverse_beam_search.tolist()) + + # check `generate()` and `diverse_beam_search()` are equal for `num_return_sequences` + num_return_sequences = 2 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, num_return_sequences=num_return_sequences + ) + + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + **beam_kwargs, + **logits_process_kwargs, + ) + # diverse_beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_ids_beam_search = model.diverse_beam_search( + input_ids_clone, + beam_kwargs["diversity_penalty"], + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) + @require_torch class UtilsFunctionsTest(unittest.TestCase): From f08cfe50c84f6593e42eae8575688afb86231d9c Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 21 Nov 2020 15:26:51 +0530 Subject: [PATCH 17/31] code format --- tests/test_generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index d15f2a8544fe..bce2acad2220 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -122,7 +122,7 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque "length_penalty": 2.0, "num_beams": 2, "num_return_sequences": num_return_sequences, - "beam_groups": 2, # one beam per group + "beam_groups": 2, # one beam per group "diversity_penalty": 1.0, } beam_scorer = DiverseBeamSearchScorer( From 9d702061b3d4cdb239c062726e0af4b6e6f1bb81 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 5 Dec 2020 13:34:57 +0530 Subject: [PATCH 18/31] documentation changes --- src/transformers/generation_beam_search.py | 2 ++ src/transformers/generation_utils.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 8287c7525e4b..54ed92d4840c 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -316,6 +316,8 @@ class DiverseBeamSearchScorer(BeamScorer): Adapted in part from `Facebook's XLM beam search code `__. + Reference for the diverse beam search algorithm and implementation + `Ashwin Kalyan's DBS implementation `__. Args: batch_size (:obj:`int`): Batch Size of :obj:`input_ids` for which diverse beam search decoding is run in parallel. diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index d608d87fa567..978cd16eef02 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -385,10 +385,12 @@ def generate( speed up decoding. beam_groups (:obj:`int`, `optional`, defaults to 1): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of - beams. See `this paper `__ for more details. + beams. To enable diverse beam search, :obj:`beam_groups` should be set to a value larger than 1. + See `this paper `__ for more details. diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. + at a particular time. Note that :obj:`diversity_penalty` is only effective if ``diverse beam search`` + is enabled. prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID @@ -507,6 +509,10 @@ def generate( is_diverse_beam_gen_mode = (num_beams > 1) and (beam_groups > 1) if beam_groups > num_beams: raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`") + if is_diverse_beam_gen_mode and do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) # set model_kwargs model_kwargs["use_cache"] = use_cache @@ -1309,7 +1315,7 @@ def diverse_beam_search( ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, - ... BeamSearchScorer, + ... DiverseBeamSearchScorer, ... ) >>> import torch From 9b61e09da42d00dcc41ccab0f506bd0dd454168d Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Sat, 5 Dec 2020 13:42:16 +0530 Subject: [PATCH 19/31] code quality --- src/transformers/generation_beam_search.py | 5 +++-- src/transformers/generation_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 54ed92d4840c..46f4b8355ee6 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -316,8 +316,9 @@ class DiverseBeamSearchScorer(BeamScorer): Adapted in part from `Facebook's XLM beam search code `__. - Reference for the diverse beam search algorithm and implementation - `Ashwin Kalyan's DBS implementation `__. + Reference for the diverse beam search algorithm and implementation `Ashwin Kalyan's DBS implementation + `__ + Args: batch_size (:obj:`int`): Batch Size of :obj:`input_ids` for which diverse beam search decoding is run in parallel. diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 978cd16eef02..397a4141af09 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -385,8 +385,8 @@ def generate( speed up decoding. beam_groups (:obj:`int`, `optional`, defaults to 1): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of - beams. To enable diverse beam search, :obj:`beam_groups` should be set to a value larger than 1. - See `this paper `__ for more details. + beams. To enable diverse beam search, :obj:`beam_groups` should be set to a value larger than 1. See + `this paper `__ for more details. diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that :obj:`diversity_penalty` is only effective if ``diverse beam search`` From b1d2269cb9e3d1e80a9fd98acb2678a5c4a7518b Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 09:17:23 +0000 Subject: [PATCH 20/31] add slow integration tests --- tests/test_generation_utils.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index bce2acad2220..bce1a23291cc 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -17,13 +17,13 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch, slow, torch_device if is_torch_available(): import torch - from transformers import top_k_top_p_filtering + from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering from transformers.generation_beam_search import BeamSearchScorer, DiverseBeamSearchScorer from transformers.generation_logits_process import ( LogitsProcessorList, @@ -622,3 +622,31 @@ def test_top_k_top_p_filtering(self): self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) + + +@require_torch +class GenerationIntegrationTests(unittest.TestCase): + @slow + def test_diverse_beam_search(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood. + The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People. + "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports. + The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both.""" + + bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = bart_model.generate( + input_ids, num_beams=4, num_return_sequences=2, beam_groups=4, diversity_penalty=2.0 + ) + + generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.", + "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.", + ], + ) From 26f9e0fc1be195bc6f330bedccb051ad48b206dd Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 09:20:36 +0000 Subject: [PATCH 21/31] more general name --- src/transformers/__init__.py | 2 +- src/transformers/generation_beam_search.py | 2 +- src/transformers/generation_utils.py | 20 ++++++++++---------- src/transformers/utils/dummy_pt_objects.py | 2 +- tests/test_generation_utils.py | 20 ++++++++++---------- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 351c4400789f..24ec177ead28 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -294,7 +294,7 @@ TextDataset, TextDatasetForNextSentencePrediction, ) - from .generation_beam_search import BeamScorer, BeamSearchScorer, DiverseBeamSearchScorer + from .generation_beam_search import BeamScorer, BeamSearchScorer, GroupBeamScorer from .generation_logits_process import ( LogitsProcessor, LogitsProcessorList, diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 46f4b8355ee6..e0b3164edf5c 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -309,7 +309,7 @@ def finalize( return decoded -class DiverseBeamSearchScorer(BeamScorer): +class GroupBeamScorer(BeamScorer): r""" :class:`transformers.BeamScorer` implementing diverse beam search decoding. diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 397a4141af09..811cfca7e285 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from .file_utils import ModelOutput -from .generation_beam_search import BeamScorer, BeamSearchScorer, DiverseBeamSearchScorer +from .generation_beam_search import BeamScorer, BeamSearchScorer, GroupBeamScorer from .generation_logits_process import ( LogitsProcessorList, MinLengthLogitsProcessor, @@ -506,10 +506,10 @@ def generate( is_sample_gen_mode = (num_beams == 1) and (beam_groups == 1) and do_sample is True is_beam_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is False is_beam_sample_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is True - is_diverse_beam_gen_mode = (num_beams > 1) and (beam_groups > 1) + is_group_beam_gen_mode = (num_beams > 1) and (beam_groups > 1) if beam_groups > num_beams: raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`") - if is_diverse_beam_gen_mode and do_sample is True: + if is_group_beam_gen_mode and do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) @@ -637,7 +637,7 @@ def generate( **model_kwargs, ) - elif is_diverse_beam_gen_mode: + elif is_group_beam_gen_mode: batch_size = input_ids.shape[0] length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty @@ -649,7 +649,7 @@ def generate( if num_beams % beam_groups != 0: raise ValueError("`num_beams` should be divisible by `beam_groups` for diverse beam search.") - diverse_beam_scorer = DiverseBeamSearchScorer( + diverse_beam_scorer = GroupBeamScorer( batch_size=batch_size, max_length=max_length, num_beams=num_beams, @@ -663,7 +663,7 @@ def generate( input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) - return self.diverse_beam_search( + return self.group_beam_search( input_ids, diversity_penalty, diverse_beam_scorer, @@ -1263,7 +1263,7 @@ def beam_sample( return decoded - def diverse_beam_search( + def group_beam_search( self, input_ids: torch.LongTensor, diversity_penalty: float, @@ -1315,7 +1315,7 @@ def diverse_beam_search( ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, - ... DiverseBeamSearchScorer, + ... GroupBeamScorer, ... ) >>> import torch @@ -1338,7 +1338,7 @@ def diverse_beam_search( ... } >>> # instantiate beam scorer - >>> beam_scorer = DiverseBeamSearchScorer( + >>> beam_scorer = GroupBeamScorer( ... batch_size=1, ... max_length=model.config.max_length, ... num_beams=num_beams, @@ -1351,7 +1351,7 @@ def diverse_beam_search( ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ]) - >>> outputs = model.diverse_beam_search(input_ids, 5.5, beam_scorer, logits_processor=logits_processor, **model_kwargs) + >>> outputs = model.group_beam_search(input_ids, 5.5, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) """ diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2d9efa7c4985..2ebab93b8b83 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs): requires_pytorch(self) -class DiverseBeamSearchScorer: +class GroupBeamScorer: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index bce1a23291cc..eac770aad0fb 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -24,7 +24,7 @@ import torch from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering - from transformers.generation_beam_search import BeamSearchScorer, DiverseBeamSearchScorer + from transformers.generation_beam_search import BeamSearchScorer, GroupBeamScorer from transformers.generation_logits_process import ( LogitsProcessorList, MinLengthLogitsProcessor, @@ -125,7 +125,7 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque "beam_groups": 2, # one beam per group "diversity_penalty": 1.0, } - beam_scorer = DiverseBeamSearchScorer( + beam_scorer = GroupBeamScorer( batch_size=batch_size, max_length=max_length, num_beams=beam_kwargs["num_beams"], @@ -430,7 +430,7 @@ def test_generate_without_input_ids(self): self.assertIsNotNone(output_ids_generate) - def test_diverse_beam_search_generate(self): + def test_group_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() @@ -441,7 +441,7 @@ def test_diverse_beam_search_generate(self): model = model_class(config).to(torch_device) model.eval() - # check `generate()` and `diverse_beam_search()` are equal + # check `generate()` and `group_beam_search()` are equal if model.config.is_encoder_decoder: max_length = 4 beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) @@ -454,7 +454,7 @@ def test_diverse_beam_search_generate(self): **logits_process_kwargs, ) - # diverse_beam_search does not automatically interleave `batch_size` dim for `num_beams` + # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` kwargs = {} if model.config.is_encoder_decoder: encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( @@ -467,7 +467,7 @@ def test_diverse_beam_search_generate(self): input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) with torch.no_grad(): - output_ids_diverse_beam_search = model.diverse_beam_search( + output_ids_group_beam_search = model.group_beam_search( input_ids_clone, beam_kwargs["diversity_penalty"], beam_scorer, @@ -476,9 +476,9 @@ def test_diverse_beam_search_generate(self): logits_processor=logits_processor, **kwargs, ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_diverse_beam_search.tolist()) + self.assertListEqual(output_ids_generate.tolist(), output_ids_group_beam_search.tolist()) - # check `generate()` and `diverse_beam_search()` are equal for `num_return_sequences` + # check `generate()` and `group_beam_search()` are equal for `num_return_sequences` num_return_sequences = 2 if model.config.is_encoder_decoder: max_length = 4 @@ -494,7 +494,7 @@ def test_diverse_beam_search_generate(self): **beam_kwargs, **logits_process_kwargs, ) - # diverse_beam_search does not automatically interleave `batch_size` dim for `num_beams` + # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` kwargs = {} if model.config.is_encoder_decoder: encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( @@ -507,7 +507,7 @@ def test_diverse_beam_search_generate(self): input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) with torch.no_grad(): - output_ids_beam_search = model.diverse_beam_search( + output_ids_beam_search = model.group_beam_search( input_ids_clone, beam_kwargs["diversity_penalty"], beam_scorer, From fb433cd3b8fd79a47985469b1ec49b4f6963b33d Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 10:44:35 +0000 Subject: [PATCH 22/31] refactor into logits processor --- src/transformers/configuration_utils.py | 8 ++ src/transformers/generation_beam_search.py | 8 +- src/transformers/generation_logits_process.py | 73 +++++++++++++++++- src/transformers/generation_utils.py | 77 +++++++++---------- tests/test_generation_utils.py | 22 ++++-- 5 files changed, 134 insertions(+), 54 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f7587faac0a2..1afc1860da61 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -97,6 +97,12 @@ class PretrainedConfig(object): sentences are finished per batch or not. - **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by default in the :obj:`generate` method of the model. 1 means no beam search. + - **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams` + into in order to ensure diversity among different groups of beams that will be used by default in the + :obj:`generate` method of the model. 1 means no group beam search. + - **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group + beam search. that will be used by default in the :obj:`generate` method of the model. 0 means no diversity + penalty. The higher the penalty, the more diverse are the outputs. - **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to module the next token probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly positive. @@ -188,6 +194,8 @@ def __init__(self, **kwargs): self.do_sample = kwargs.pop("do_sample", False) self.early_stopping = kwargs.pop("early_stopping", False) self.num_beams = kwargs.pop("num_beams", 1) + self.num_beam_groups = kwargs.pop("num_beam_groups", 1) + self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) self.temperature = kwargs.pop("temperature", 1.0) self.top_k = kwargs.pop("top_k", 50) self.top_p = kwargs.pop("top_p", 1.0) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index e0b3164edf5c..486009012a61 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -326,7 +326,7 @@ class GroupBeamScorer(BeamScorer): The maximum length of the sequence to be generated. num_beams (:obj:`int`): Number of beams for beam search. - beam_groups (:obj:`int`): + num_beam_groups (:obj:`int`): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of beams. See `this paper `__ for more details. device (:obj:`torch.device`): @@ -348,7 +348,7 @@ def __init__( batch_size: int, max_length: int, num_beams: int, - beam_groups: int, + num_beam_groups: int, device: torch.device, length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, @@ -360,8 +360,8 @@ def __init__( self.length_penalty = length_penalty self.do_early_stopping = do_early_stopping self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - self.beam_groups = beam_groups - self.group_size = self.num_beams // self.beam_groups + self.num_beam_groups = num_beam_groups + self.group_size = self.num_beams // self.num_beam_groups self._is_init = False self._beam_hyps = [ diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index dc6b183c4f5b..7b3925e579d9 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import math from abc import ABC from typing import Callable, Iterable, List @@ -37,6 +38,8 @@ scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax or scores for each vocabulary token after SoftMax. + kwargs: + Additional logits processor specific kwargs. Return: :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. @@ -75,9 +78,16 @@ class LogitsProcessorList(list): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: for processor in self: - scores = processor(input_ids, scores) + function_args = inspect.signature(processor.__call__).parameters + if len(function_args) > 2: + assert all( + arg in kwargs for arg in list(function_args.keys())[2:] + ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." + scores = processor(input_ids, scores, **kwargs) + else: + scores = processor(input_ids, scores) return scores @@ -400,3 +410,62 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 return scores + mask + + +class HammingDiversityLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only + effective for `group_beam_search`. See `Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models + `__ for more details. + + Args: + diversity_penalty (:obj:`float`, `optional`): + This value is subtracted from a beam's score if it generates a token same as any beam from other group at a + particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled. + num_beams (:obj:`int`, `optional`): + Number of beams used for group beam search. See `this paper `__ for + more details. + num_beam_groups (:obj:`int`, `optional`): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. + """ + + def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): + if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): + raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") + self._diversity_penalty = diversity_penalty + if not isinstance(num_beams, int) or num_beams < 2: + raise ValueError("`num_beams` should be an integer strictly larger than 1.") + self._num_beams = num_beams + if not isinstance(num_beam_groups, int) or num_beam_groups < 2: + raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") + if num_beam_groups > num_beams: + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.") + if num_beam_groups > num_beams: + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`") + self._num_sub_beams = num_beams // num_beam_groups + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + recent_tokens: torch.LongTensor, + beam_group_idx: int, + ) -> torch.FloatTensor: + # hamming diversity: penalise using same token in current group which was used in previous groups at + # the same time step + batch_size = recent_tokens.shape[0] // self._num_beams + group_start_idx = beam_group_idx * self._num_sub_beams + group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) + group_size = group_end_idx - group_start_idx + vocab_size = scores.shape[-1] + + for batch_idx in range(batch_size): + # predicted tokens of last time step of previous groups + previous_group_tokens = recent_tokens[ + batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx + ] + token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device) + scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency + + return scores diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 811cfca7e285..c14cae57646d 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -22,6 +22,7 @@ from .file_utils import ModelOutput from .generation_beam_search import BeamScorer, BeamSearchScorer, GroupBeamScorer from .generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -261,6 +262,8 @@ def _get_logits_processor( eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, + num_beam_groups: int, + diversity_penalty: float, ) -> LogitsProcessorList: """ This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant @@ -275,11 +278,18 @@ def _get_logits_processor( bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty # instantiate processors list processors = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` + if diversity_penalty is not None and diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + ) if repetition_penalty is not None and repetition_penalty != 1.0: processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: @@ -314,8 +324,8 @@ def generate( num_return_sequences: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, - beam_groups: Optional[int] = 1, - diversity_penalty: Optional[float] = 0.0, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, **model_kwargs ) -> torch.LongTensor: @@ -383,14 +393,13 @@ def generate( use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. - beam_groups (:obj:`int`, `optional`, defaults to 1): + num_beam_groups (:obj:`int`, `optional`, defaults to 1): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of - beams. To enable diverse beam search, :obj:`beam_groups` should be set to a value larger than 1. See - `this paper `__ for more details. + beams. `this paper `__ for more details. diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. Note that :obj:`diversity_penalty` is only effective if ``diverse beam search`` - is enabled. + at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is + enabled. prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID @@ -463,6 +472,7 @@ def generate( # set init values num_beams = num_beams if num_beams is not None else self.config.num_beams + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups max_length = max_length if max_length is not None else self.config.max_length do_sample = do_sample if do_sample is not None else self.config.do_sample num_return_sequences = ( @@ -502,13 +512,13 @@ def generate( raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") # determine generation mode - is_greedy_gen_mode = (num_beams == 1) and (beam_groups == 1) and do_sample is False - is_sample_gen_mode = (num_beams == 1) and (beam_groups == 1) and do_sample is True - is_beam_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is False - is_beam_sample_gen_mode = (num_beams > 1) and (beam_groups == 1) and do_sample is True - is_group_beam_gen_mode = (num_beams > 1) and (beam_groups > 1) - if beam_groups > num_beams: - raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`") + is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False + is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) + if num_beam_groups > num_beams: + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." @@ -526,6 +536,8 @@ def generate( eos_token_id=eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, ) if is_greedy_gen_mode: @@ -646,14 +658,14 @@ def generate( if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") - if num_beams % beam_groups != 0: - raise ValueError("`num_beams` should be divisible by `beam_groups` for diverse beam search.") + if num_beams % num_beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") diverse_beam_scorer = GroupBeamScorer( batch_size=batch_size, max_length=max_length, num_beams=num_beams, - beam_groups=beam_groups, + num_beam_groups=num_beam_groups, device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, @@ -665,7 +677,6 @@ def generate( ) return self.group_beam_search( input_ids, - diversity_penalty, diverse_beam_scorer, logits_processor=logits_processor, max_length=max_length, @@ -1266,7 +1277,6 @@ def beam_sample( def group_beam_search( self, input_ids: torch.LongTensor, - diversity_penalty: float, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, @@ -1282,9 +1292,6 @@ def group_beam_search( input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. - diversity_penalty (:obj:`float`): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. beam_scorer (:obj:`BeamScorer`): An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of @@ -1300,7 +1307,7 @@ def group_beam_search( eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. model_kwargs: - Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If + Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. Return: @@ -1343,7 +1350,7 @@ def group_beam_search( ... max_length=model.config.max_length, ... num_beams=num_beams, ... device=model.device, - ... beam_groups=3 + ... num_beam_groups=3 ... ) >>> # instantiate logits processors @@ -1364,8 +1371,8 @@ def group_beam_search( batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams - beam_groups = beam_scorer.beam_groups - num_sub_beams = num_beams // beam_groups + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups device = input_ids.device batch_beam_size, cur_len = input_ids.shape @@ -1391,7 +1398,7 @@ def group_beam_search( model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self(**model_inputs, return_dict=True) - for beam_group_idx in range(beam_groups): + for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx @@ -1415,19 +1422,9 @@ def group_beam_search( next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) vocab_size = next_token_scores.shape[-1] - # hamming diversity: penalise using same token in current group which was used in previous groups at - # the same time step - for batch_idx in range(batch_size): - # predicted tokens of last time step of previous groups - previous_group_tokens = recent_tokens[ - batch_idx * num_beams : batch_idx * num_beams + group_start_idx - ] - token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(device) - next_token_scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= ( - diversity_penalty * token_frequency - ) - - next_token_scores = logits_processor(group_input_ids, next_token_scores) + next_token_scores = logits_processor( + group_input_ids, next_token_scores, recent_tokens=recent_tokens, beam_group_idx=beam_group_idx + ) next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as( next_token_scores ) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index eac770aad0fb..bc922c7dbf63 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -26,6 +26,7 @@ from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering from transformers.generation_beam_search import BeamSearchScorer, GroupBeamScorer from transformers.generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -61,7 +62,7 @@ def _get_input_ids_and_config(self): return config, input_ids, attention_mask, max_length @staticmethod - def _get_logits_processor_and_kwargs(input_length, eos_token_id): + def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None): process_kwargs = { "min_length": input_length + 1, "bad_words_ids": [[1, 0]], @@ -70,6 +71,13 @@ def _get_logits_processor_and_kwargs(input_length, eos_token_id): } logits_processor = LogitsProcessorList( ( + [ + HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2), + ] + if diversity_penalty is not None + else [] + ) + + ( [ MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), ] @@ -122,14 +130,14 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque "length_penalty": 2.0, "num_beams": 2, "num_return_sequences": num_return_sequences, - "beam_groups": 2, # one beam per group - "diversity_penalty": 1.0, + "num_beam_groups": 2, # one beam per group + "diversity_penalty": 2.0, } beam_scorer = GroupBeamScorer( batch_size=batch_size, max_length=max_length, num_beams=beam_kwargs["num_beams"], - beam_groups=beam_kwargs["beam_groups"], + num_beam_groups=beam_kwargs["num_beam_groups"], device=torch_device, length_penalty=beam_kwargs["length_penalty"], do_early_stopping=beam_kwargs["early_stopping"], @@ -435,7 +443,7 @@ def test_group_beam_search_generate(self): config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id + input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 ) model = model_class(config).to(torch_device) @@ -469,7 +477,6 @@ def test_group_beam_search_generate(self): with torch.no_grad(): output_ids_group_beam_search = model.group_beam_search( input_ids_clone, - beam_kwargs["diversity_penalty"], beam_scorer, max_length=max_length, attention_mask=attention_mask_clone, @@ -509,7 +516,6 @@ def test_group_beam_search_generate(self): with torch.no_grad(): output_ids_beam_search = model.group_beam_search( input_ids_clone, - beam_kwargs["diversity_penalty"], beam_scorer, max_length=max_length, attention_mask=attention_mask_clone, @@ -638,7 +644,7 @@ def test_diverse_beam_search(self): input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) outputs = bart_model.generate( - input_ids, num_beams=4, num_return_sequences=2, beam_groups=4, diversity_penalty=2.0 + input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0 ) generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) From 6b61b8ea35f4f2de2d60224d7559785f0d214cc0 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 11:24:52 +0000 Subject: [PATCH 23/31] add test --- src/transformers/generation_logits_process.py | 9 ++++-- src/transformers/generation_utils.py | 8 +++--- src/transformers/models/rag/modeling_rag.py | 12 ++++++++ tests/test_generation_logits_process.py | 28 +++++++++++++++++++ 4 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 7b3925e579d9..98cd7186d704 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -449,20 +449,23 @@ def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, - recent_tokens: torch.LongTensor, + current_tokens: torch.LongTensor, beam_group_idx: int, ) -> torch.FloatTensor: # hamming diversity: penalise using same token in current group which was used in previous groups at # the same time step - batch_size = recent_tokens.shape[0] // self._num_beams + batch_size = current_tokens.shape[0] // self._num_beams group_start_idx = beam_group_idx * self._num_sub_beams group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) group_size = group_end_idx - group_start_idx vocab_size = scores.shape[-1] + if group_start_idx == 0: + return scores + for batch_idx in range(batch_size): # predicted tokens of last time step of previous groups - previous_group_tokens = recent_tokens[ + previous_group_tokens = current_tokens[ batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx ] token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index c14cae57646d..0a5526704499 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1389,7 +1389,7 @@ def group_beam_search( while cur_len < max_length: # predicted tokens in cur_len step - recent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) @@ -1423,7 +1423,7 @@ def group_beam_search( vocab_size = next_token_scores.shape[-1] next_token_scores = logits_processor( - group_input_ids, next_token_scores, recent_tokens=recent_tokens, beam_group_idx=beam_group_idx + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as( next_token_scores @@ -1454,7 +1454,7 @@ def group_beam_search( input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - recent_tokens[batch_group_indices] = group_input_ids[:, -1] + current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group @@ -1468,7 +1468,7 @@ def group_beam_search( if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) - input_ids = torch.cat([input_ids, recent_tokens.unsqueeze(-1)], dim=-1) + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) cur_len = cur_len + 1 if beam_scorer.is_done: break diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 31de9b3922b0..9280f3f15be4 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1219,6 +1219,8 @@ def generate( early_stopping=None, use_cache=None, num_beams=None, + num_beam_groups=None, + diversity_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_id=None, @@ -1295,6 +1297,13 @@ def generate( should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. num_beams (:obj:`int`, `optional`, defaults to 1): Number of beams for beam search. 1 means no beam search. + num_beam_groups (:obj:`int`, `optional`, defaults to 1): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. `this paper `__ for more details. + diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is + enabled. num_return_sequences(:obj:`int`, `optional`, defaults to 1): The number of independently computed returned sequences for each element in the batch. Note that this is not the value we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate` @@ -1319,6 +1328,7 @@ def generate( # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs num_beams = num_beams if num_beams is not None else self.config.num_beams + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups max_length = max_length if max_length is not None else self.config.max_length num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences @@ -1405,6 +1415,8 @@ def extend_enc_output(tensor, num_beams=None): eos_token_id=eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, ) if num_beams == 1: diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py index 7dd0d0551785..1aa2941047f4 100644 --- a/tests/test_generation_logits_process.py +++ b/tests/test_generation_logits_process.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from transformers.generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -302,3 +303,30 @@ def prefix_allowed_tokens_fn(batch_id, inputs_ids): self.assertListEqual( torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]] ) + + def test_hamming_diversity(self): + vocab_size = 4 + num_beams = 2 + num_beam_groups = 2 + + scores = self._get_uniform_logits(num_beams, vocab_size) + # batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1 + # batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1 + current_tokens = torch.tensor([0, 3, 1, 2], device=torch_device, dtype=torch.long) + + diversity_logits_processor = HammingDiversityLogitsProcessor( + diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + + processed_scores = diversity_logits_processor(None, scores, current_tokens, 1) + + self.assertTrue( + torch.allclose( + processed_scores[0], torch.tensor([-0.7500, 0.2500, 0.2500, 0.2500], device=torch_device), atol=1e-3 + ) + ) + self.assertTrue( + torch.allclose( + processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3 + ) + ) From 68d841c483fdd99871b7b61c7e884344d6355137 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 12:18:46 +0000 Subject: [PATCH 24/31] avoid too much copy paste --- src/transformers/__init__.py | 2 +- src/transformers/generation_beam_search.py | 216 +------ .../generation_beam_search_comp.py | 556 ++++++++++++++++++ src/transformers/generation_utils.py | 10 +- tests/test_generation_utils.py | 6 +- 5 files changed, 577 insertions(+), 213 deletions(-) create mode 100644 src/transformers/generation_beam_search_comp.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 24ec177ead28..a92fb488125b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -294,7 +294,7 @@ TextDataset, TextDatasetForNextSentencePrediction, ) - from .generation_beam_search import BeamScorer, BeamSearchScorer, GroupBeamScorer + from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( LogitsProcessor, LogitsProcessorList, diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 486009012a61..90e799c928ef 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -116,200 +116,6 @@ def finalize( class BeamSearchScorer(BeamScorer): - r""" - :class:`transformers.BeamScorer` implementing standard beam search decoding. - - Adapted in part from `Facebook's XLM beam search code - `__. - - Args: - batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. - max_length (:obj:`int`): - The maximum length of the sequence to be generated. - num_beams (:obj:`int`): - Number of beams for beam search. - device (:obj:`torch.device`): - Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of - :obj:`BeamSearchScorer` will be allocated. - length_penalty (:obj:`float`, `optional`, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. - do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. - num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): - The number of beam hypotheses that shall be returned upon calling - :meth:`~transformer.BeamSearchScorer.finalize`. - """ - - def __init__( - self, - batch_size: int, - max_length: int, - num_beams: int, - device: torch.device, - length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - ): - self.max_length = max_length - self.num_beams = num_beams - self.device = device - self.length_penalty = length_penalty - self.do_early_stopping = do_early_stopping - self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - - self._is_init = False - self._beam_hyps = [ - BeamHypotheses( - num_beams=self.num_beams, - max_length=self.max_length, - length_penalty=self.length_penalty, - early_stopping=self.do_early_stopping, - ) - for _ in range(batch_size) - ] - self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) - - if not isinstance(num_beams, int) or num_beams <= 1: - raise ValueError( - f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." - ) - - @property - def is_done(self) -> bool: - return self._done.all() - - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> Tuple[torch.Tensor]: - cur_len = input_ids.shape[-1] - batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.num_beams) - - device = input_ids.device - next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) - - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - assert ( - len(beam_hyp) >= self.num_beams - ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) - assert ( - eos_token_id is not None and pad_token_id is not None - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" - # pad the batch - next_beam_scores[batch_idx, :] = 0 - next_beam_tokens[batch_idx, :] = pad_token_id - next_beam_indices[batch_idx, :] = 0 - continue - - # next tokens for this sentence - beam_idx = 0 - for beam_token_rank, (next_token, next_score, next_index) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) - ): - batch_beam_idx = batch_idx * self.num_beams + next_index - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams - if is_beam_token_worse_than_top_num_beams: - continue - beam_hyp.add( - input_ids[batch_beam_idx].clone(), - next_score.item(), - ) - else: - # add next predicted token since it is not eos_token - next_beam_scores[batch_idx, beam_idx] = next_score - next_beam_tokens[batch_idx, beam_idx] = next_token - next_beam_indices[batch_idx, beam_idx] = batch_beam_idx - beam_idx += 1 - - # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.num_beams: - break - - if beam_idx < self.num_beams: - raise ValueError( - f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." - ) - - # Check if we are done so that we can save a pad step if all(done) - self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( - next_scores[batch_idx].max().item(), cur_len - ) - - return UserDict( - { - "next_beam_scores": next_beam_scores.view(-1), - "next_beam_tokens": next_beam_tokens.view(-1), - "next_beam_indices": next_beam_indices.view(-1), - } - ) - - def finalize( - self, - input_ids: torch.LongTensor, - final_beam_scores: torch.FloatTensor, - final_beam_tokens: torch.LongTensor, - final_beam_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> torch.LongTensor: - batch_size = len(self._beam_hyps) - - # finalize all open beam hypotheses and add to generated hypotheses - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - continue - - # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() - final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score) - - # select the best hypotheses - sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) - best = [] - - # retrieve best hypotheses - for i, beam_hyp in enumerate(self._beam_hyps): - sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) - for j in range(self.num_beam_hyps_to_keep): - best_hyp = sorted_hyps.pop()[1] - sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - best.append(best_hyp) - - # prepare for adding eos - sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) - decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) - # shorter batches are padded if needed - if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`pad_token_id` has to be defined" - decoded.fill_(pad_token_id) - - # fill with hypotheses and eos_token_id if the latter fits in - for i, hypo in enumerate(best): - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < self.max_length: - decoded[i, sent_lengths[i]] = eos_token_id - return decoded - - -class GroupBeamScorer(BeamScorer): r""" :class:`transformers.BeamScorer` implementing diverse beam search decoding. @@ -326,9 +132,6 @@ class GroupBeamScorer(BeamScorer): The maximum length of the sequence to be generated. num_beams (:obj:`int`): Number of beams for beam search. - num_beam_groups (:obj:`int`): - Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of - beams. See `this paper `__ for more details. device (:obj:`torch.device`): Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of :obj:`BeamSearchScorer` will be allocated. @@ -341,6 +144,9 @@ class GroupBeamScorer(BeamScorer): num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): The number of beam hypotheses that shall be returned upon calling :meth:`~transformer.BeamSearchScorer.finalize`. + num_beam_groups (:obj:`int`): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. """ def __init__( @@ -348,11 +154,11 @@ def __init__( batch_size: int, max_length: int, num_beams: int, - num_beam_groups: int, device: torch.device, length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, num_beam_hyps_to_keep: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, ): self.max_length = max_length self.num_beams = num_beams @@ -380,6 +186,11 @@ def __init__( f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." ) + if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): + raise ValueError( + f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + ) + @property def is_done(self) -> bool: return self._done.all() @@ -471,19 +282,16 @@ def finalize( eos_token_id: Optional[int] = None, ) -> torch.LongTensor: batch_size = len(self._beam_hyps) - final_beam_scores = final_beam_scores.view((batch_size, self.num_beams)) # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: continue - batch_beam_scores = final_beam_scores[batch_idx, :] - _, beam_ids = torch.sort(batch_beam_scores, descending=True) # need to add best num_beams hypotheses to generated hyps - for beam_id in beam_ids: - batch_beam_idx = batch_idx * self.num_beams + beam_id.item() - final_score = batch_beam_scores[beam_id.item()].item() + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] beam_hyp.add(final_tokens, final_score) diff --git a/src/transformers/generation_beam_search_comp.py b/src/transformers/generation_beam_search_comp.py new file mode 100644 index 000000000000..10c14181365e --- /dev/null +++ b/src/transformers/generation_beam_search_comp.py @@ -0,0 +1,556 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import UserDict +from typing import Optional, Tuple + +import torch + +from .file_utils import add_start_docstrings + + +PROCESS_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. + next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. + next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + + Return: + :obj:`UserDict`: A dictionary composed of the fields as defined above: + + - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated + scores of all non-finished beams. + - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens + to be added to the non-finished beam_hypotheses. + - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices + indicating to which beam the next tokens shall be added. + +""" + +FINALIZE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The final scores of all non-finished beams. + final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The last tokens to be added to the non-finished beam_hypotheses. + final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + +""" + + +class BeamScorer(ABC): + """ + Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and + :meth:`~transformers.PretrainedModel.beam_sample`. + """ + + @abstractmethod + @add_start_docstrings(PROCESS_INPUTS_DOCSTRING) + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs + ) -> Tuple[torch.Tensor]: + raise NotImplementedError("This is an abstract method.") + + @abstractmethod + @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING) + def finalize( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs + ) -> torch.LongTensor: + raise NotImplementedError("This is an abstract method.") + + +class BeamSearchScorer(BeamScorer): + r""" + :class:`transformers.BeamScorer` implementing standard beam search decoding. + + Adapted in part from `Facebook's XLM beam search code + `__. + + Args: + batch_size (:obj:`int`): + Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + num_beams (:obj:`int`): + Number of beams for beam search. + device (:obj:`torch.device`): + Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of + :obj:`BeamSearchScorer` will be allocated. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + :meth:`~transformer.BeamSearchScorer.finalize`. + """ + + def __init__( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + max_length=self.max_length, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + assert batch_size == (input_ids.shape[0] // self.num_beams) + + device = input_ids.device + next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + assert ( + len(beam_hyp) >= self.num_beams + ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.num_beams + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.num_beams: + break + + if beam_idx < self.num_beams: + raise ValueError( + f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> torch.LongTensor: + batch_size = len(self._beam_hyps) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add(final_tokens, final_score) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp = sorted_hyps.pop()[1] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + best.append(best_hyp) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < self.max_length: + decoded[i, sent_lengths[i]] = eos_token_id + return decoded + + +class GroupBeamScorer(BeamScorer): + r""" + :class:`transformers.BeamScorer` implementing standard beam search decoding. + + Adapted in part from `Facebook's XLM beam search code + `__. + + Args: + batch_size (:obj:`int`): + Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + num_beams (:obj:`int`): + Number of beams for beam search. + device (:obj:`torch.device`): + Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of + :obj:`BeamSearchScorer` will be allocated. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + :meth:`~transformer.BeamSearchScorer.finalize`. + num_beam_groups (:obj:`int`, defaults to 1): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. + """ + + def __init__( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.group_size = self.num_beams // self.num_beam_groups + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + max_length=self.max_length, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + assert batch_size == (input_ids.shape[0] // self.group_size) + + device = input_ids.device + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + assert ( + len(beam_hyp) >= self.num_beams + ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.group_size + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.group_size: + break + + if beam_idx < self.group_size: + raise ValueError( + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> torch.LongTensor: + batch_size = len(self._beam_hyps) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add(final_tokens, final_score) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp = sorted_hyps.pop()[1] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + best.append(best_hyp) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < self.max_length: + decoded[i, sent_lengths[i]] = eos_token_id + return decoded + + +class BeamHypotheses: + def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp: torch.LongTensor, sum_logprobs: float): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 0a5526704499..2a5e27e138cf 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from .file_utils import ModelOutput -from .generation_beam_search import BeamScorer, BeamSearchScorer, GroupBeamScorer +from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( HammingDiversityLogitsProcessor, LogitsProcessorList, @@ -661,15 +661,15 @@ def generate( if num_beams % num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") - diverse_beam_scorer = GroupBeamScorer( + diverse_beam_scorer = BeamSearchScorer( batch_size=batch_size, max_length=max_length, num_beams=num_beams, - num_beam_groups=num_beam_groups, device=self.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=num_beam_groups, ) # interleave with `num_beams` input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -1322,7 +1322,7 @@ def group_beam_search( ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, - ... GroupBeamScorer, + ... BeamSearchScorer, ... ) >>> import torch @@ -1345,7 +1345,7 @@ def group_beam_search( ... } >>> # instantiate beam scorer - >>> beam_scorer = GroupBeamScorer( + >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... max_length=model.config.max_length, ... num_beams=num_beams, diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index bc922c7dbf63..28559fb75426 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -24,7 +24,7 @@ import torch from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering - from transformers.generation_beam_search import BeamSearchScorer, GroupBeamScorer + from transformers.generation_beam_search import BeamSearchScorer from transformers.generation_logits_process import ( HammingDiversityLogitsProcessor, LogitsProcessorList, @@ -133,15 +133,15 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque "num_beam_groups": 2, # one beam per group "diversity_penalty": 2.0, } - beam_scorer = GroupBeamScorer( + beam_scorer = BeamSearchScorer( batch_size=batch_size, max_length=max_length, num_beams=beam_kwargs["num_beams"], - num_beam_groups=beam_kwargs["num_beam_groups"], device=torch_device, length_penalty=beam_kwargs["length_penalty"], do_early_stopping=beam_kwargs["early_stopping"], num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=beam_kwargs["num_beam_groups"], ) return beam_kwargs, beam_scorer From 8751681c8da2f79667cf288035c1907553a89755 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 12:24:42 +0000 Subject: [PATCH 25/31] refactor --- src/transformers/generation_beam_search.py | 4 +- .../generation_beam_search_comp.py | 556 ------------------ src/transformers/generation_utils.py | 3 +- src/transformers/utils/dummy_pt_objects.py | 5 - 4 files changed, 4 insertions(+), 564 deletions(-) delete mode 100644 src/transformers/generation_beam_search_comp.py diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 90e799c928ef..55b7dbc143da 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -117,7 +117,7 @@ def finalize( class BeamSearchScorer(BeamScorer): r""" - :class:`transformers.BeamScorer` implementing diverse beam search decoding. + :class:`transformers.BeamScorer` implementing standard beam search decoding. Adapted in part from `Facebook's XLM beam search code `__. @@ -127,7 +127,7 @@ class BeamSearchScorer(BeamScorer): Args: batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which diverse beam search decoding is run in parallel. + Batch Size of :obj:`input_ids` for which standard beam search decoding is run in parallel. max_length (:obj:`int`): The maximum length of the sequence to be generated. num_beams (:obj:`int`): diff --git a/src/transformers/generation_beam_search_comp.py b/src/transformers/generation_beam_search_comp.py deleted file mode 100644 index 10c14181365e..000000000000 --- a/src/transformers/generation_beam_search_comp.py +++ /dev/null @@ -1,556 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from collections import UserDict -from typing import Optional, Tuple - -import torch - -from .file_utils import add_start_docstrings - - -PROCESS_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. - - `What are input IDs? <../glossary.html#input-ids>`__ - next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): - Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. - next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): - :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. - next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): - Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. - pad_token_id (:obj:`int`, `optional`): - The id of the `padding` token. - eos_token_id (:obj:`int`, `optional`): - The id of the `end-of-sequence` token. - - Return: - :obj:`UserDict`: A dictionary composed of the fields as defined above: - - - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated - scores of all non-finished beams. - - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens - to be added to the non-finished beam_hypotheses. - - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices - indicating to which beam the next tokens shall be added. - -""" - -FINALIZE_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. - - `What are input IDs? <../glossary.html#input-ids>`__ - final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): - The final scores of all non-finished beams. - final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): - The last tokens to be added to the non-finished beam_hypotheses. - final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): - The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. - pad_token_id (:obj:`int`, `optional`): - The id of the `padding` token. - eos_token_id (:obj:`int`, `optional`): - The id of the `end-of-sequence` token. - - Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`. - -""" - - -class BeamScorer(ABC): - """ - Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and - :meth:`~transformers.PretrainedModel.beam_sample`. - """ - - @abstractmethod - @add_start_docstrings(PROCESS_INPUTS_DOCSTRING) - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - **kwargs - ) -> Tuple[torch.Tensor]: - raise NotImplementedError("This is an abstract method.") - - @abstractmethod - @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING) - def finalize( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - **kwargs - ) -> torch.LongTensor: - raise NotImplementedError("This is an abstract method.") - - -class BeamSearchScorer(BeamScorer): - r""" - :class:`transformers.BeamScorer` implementing standard beam search decoding. - - Adapted in part from `Facebook's XLM beam search code - `__. - - Args: - batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. - max_length (:obj:`int`): - The maximum length of the sequence to be generated. - num_beams (:obj:`int`): - Number of beams for beam search. - device (:obj:`torch.device`): - Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of - :obj:`BeamSearchScorer` will be allocated. - length_penalty (:obj:`float`, `optional`, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. - do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. - num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): - The number of beam hypotheses that shall be returned upon calling - :meth:`~transformer.BeamSearchScorer.finalize`. - """ - - def __init__( - self, - batch_size: int, - max_length: int, - num_beams: int, - device: torch.device, - length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - ): - self.max_length = max_length - self.num_beams = num_beams - self.device = device - self.length_penalty = length_penalty - self.do_early_stopping = do_early_stopping - self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - - self._is_init = False - self._beam_hyps = [ - BeamHypotheses( - num_beams=self.num_beams, - max_length=self.max_length, - length_penalty=self.length_penalty, - early_stopping=self.do_early_stopping, - ) - for _ in range(batch_size) - ] - self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) - - if not isinstance(num_beams, int) or num_beams <= 1: - raise ValueError( - f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." - ) - - @property - def is_done(self) -> bool: - return self._done.all() - - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> Tuple[torch.Tensor]: - cur_len = input_ids.shape[-1] - batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.num_beams) - - device = input_ids.device - next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) - - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - assert ( - len(beam_hyp) >= self.num_beams - ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) - assert ( - eos_token_id is not None and pad_token_id is not None - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" - # pad the batch - next_beam_scores[batch_idx, :] = 0 - next_beam_tokens[batch_idx, :] = pad_token_id - next_beam_indices[batch_idx, :] = 0 - continue - - # next tokens for this sentence - beam_idx = 0 - for beam_token_rank, (next_token, next_score, next_index) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) - ): - batch_beam_idx = batch_idx * self.num_beams + next_index - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams - if is_beam_token_worse_than_top_num_beams: - continue - beam_hyp.add( - input_ids[batch_beam_idx].clone(), - next_score.item(), - ) - else: - # add next predicted token since it is not eos_token - next_beam_scores[batch_idx, beam_idx] = next_score - next_beam_tokens[batch_idx, beam_idx] = next_token - next_beam_indices[batch_idx, beam_idx] = batch_beam_idx - beam_idx += 1 - - # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.num_beams: - break - - if beam_idx < self.num_beams: - raise ValueError( - f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." - ) - - # Check if we are done so that we can save a pad step if all(done) - self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( - next_scores[batch_idx].max().item(), cur_len - ) - - return UserDict( - { - "next_beam_scores": next_beam_scores.view(-1), - "next_beam_tokens": next_beam_tokens.view(-1), - "next_beam_indices": next_beam_indices.view(-1), - } - ) - - def finalize( - self, - input_ids: torch.LongTensor, - final_beam_scores: torch.FloatTensor, - final_beam_tokens: torch.LongTensor, - final_beam_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> torch.LongTensor: - batch_size = len(self._beam_hyps) - - # finalize all open beam hypotheses and add to generated hypotheses - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - continue - - # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() - final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score) - - # select the best hypotheses - sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) - best = [] - - # retrieve best hypotheses - for i, beam_hyp in enumerate(self._beam_hyps): - sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) - for j in range(self.num_beam_hyps_to_keep): - best_hyp = sorted_hyps.pop()[1] - sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - best.append(best_hyp) - - # prepare for adding eos - sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) - decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) - # shorter batches are padded if needed - if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`pad_token_id` has to be defined" - decoded.fill_(pad_token_id) - - # fill with hypotheses and eos_token_id if the latter fits in - for i, hypo in enumerate(best): - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < self.max_length: - decoded[i, sent_lengths[i]] = eos_token_id - return decoded - - -class GroupBeamScorer(BeamScorer): - r""" - :class:`transformers.BeamScorer` implementing standard beam search decoding. - - Adapted in part from `Facebook's XLM beam search code - `__. - - Args: - batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. - max_length (:obj:`int`): - The maximum length of the sequence to be generated. - num_beams (:obj:`int`): - Number of beams for beam search. - device (:obj:`torch.device`): - Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of - :obj:`BeamSearchScorer` will be allocated. - length_penalty (:obj:`float`, `optional`, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. - do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. - num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): - The number of beam hypotheses that shall be returned upon calling - :meth:`~transformer.BeamSearchScorer.finalize`. - num_beam_groups (:obj:`int`, defaults to 1): - Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of - beams. See `this paper `__ for more details. - """ - - def __init__( - self, - batch_size: int, - max_length: int, - num_beams: int, - device: torch.device, - length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - num_beam_groups: Optional[int] = 1, - ): - self.max_length = max_length - self.num_beams = num_beams - self.device = device - self.length_penalty = length_penalty - self.do_early_stopping = do_early_stopping - self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - self.group_size = self.num_beams // self.num_beam_groups - - self._is_init = False - self._beam_hyps = [ - BeamHypotheses( - num_beams=self.num_beams, - max_length=self.max_length, - length_penalty=self.length_penalty, - early_stopping=self.do_early_stopping, - ) - for _ in range(batch_size) - ] - self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) - - if not isinstance(num_beams, int) or num_beams <= 1: - raise ValueError( - f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." - ) - - @property - def is_done(self) -> bool: - return self._done.all() - - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> Tuple[torch.Tensor]: - cur_len = input_ids.shape[-1] - batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.group_size) - - device = input_ids.device - next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) - - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - assert ( - len(beam_hyp) >= self.num_beams - ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) - assert ( - eos_token_id is not None and pad_token_id is not None - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" - # pad the batch - next_beam_scores[batch_idx, :] = 0 - next_beam_tokens[batch_idx, :] = pad_token_id - next_beam_indices[batch_idx, :] = 0 - continue - - # next tokens for this sentence - beam_idx = 0 - for beam_token_rank, (next_token, next_score, next_index) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) - ): - batch_beam_idx = batch_idx * self.group_size + next_index - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size - if is_beam_token_worse_than_top_num_beams: - continue - beam_hyp.add( - input_ids[batch_beam_idx].clone(), - next_score.item(), - ) - else: - # add next predicted token since it is not eos_token - next_beam_scores[batch_idx, beam_idx] = next_score - next_beam_tokens[batch_idx, beam_idx] = next_token - next_beam_indices[batch_idx, beam_idx] = batch_beam_idx - beam_idx += 1 - - # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.group_size: - break - - if beam_idx < self.group_size: - raise ValueError( - f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." - ) - - # Check if we are done so that we can save a pad step if all(done) - self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( - next_scores[batch_idx].max().item(), cur_len - ) - - return UserDict( - { - "next_beam_scores": next_beam_scores.view(-1), - "next_beam_tokens": next_beam_tokens.view(-1), - "next_beam_indices": next_beam_indices.view(-1), - } - ) - - def finalize( - self, - input_ids: torch.LongTensor, - final_beam_scores: torch.FloatTensor, - final_beam_tokens: torch.LongTensor, - final_beam_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> torch.LongTensor: - batch_size = len(self._beam_hyps) - - # finalize all open beam hypotheses and add to generated hypotheses - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - continue - - # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() - final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score) - - # select the best hypotheses - sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) - best = [] - - # retrieve best hypotheses - for i, beam_hyp in enumerate(self._beam_hyps): - sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) - for j in range(self.num_beam_hyps_to_keep): - best_hyp = sorted_hyps.pop()[1] - sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - best.append(best_hyp) - - # prepare for adding eos - sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) - decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) - # shorter batches are padded if needed - if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`pad_token_id` has to be defined" - decoded.fill_(pad_token_id) - - # fill with hypotheses and eos_token_id if the latter fits in - for i, hypo in enumerate(best): - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < self.max_length: - decoded[i, sent_lengths[i]] = eos_token_id - return decoded - - -class BeamHypotheses: - def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): - """ - Initialize n-best list of hypotheses. - """ - self.max_length = max_length - 1 # ignoring bos_token - self.length_penalty = length_penalty - self.early_stopping = early_stopping - self.num_beams = num_beams - self.beams = [] - self.worst_score = 1e9 - - def __len__(self): - """ - Number of hypotheses in the list. - """ - return len(self.beams) - - def add(self, hyp: torch.LongTensor, sum_logprobs: float): - """ - Add a new hypothesis to the list. - """ - score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) - if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp)) - if len(self) > self.num_beams: - sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) - del self.beams[sorted_next_scores[0][1]] - self.worst_score = sorted_next_scores[1][0] - else: - self.worst_score = min(score, self.worst_score) - - def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: - """ - If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst - one in the heap, then we are done with this sentence. - """ - - if len(self) < self.num_beams: - return False - elif self.early_stopping: - return True - else: - cur_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= cur_score - return ret diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 2a5e27e138cf..3dcc3301eead 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1355,10 +1355,11 @@ def group_beam_search( >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList([ + ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ]) - >>> outputs = model.group_beam_search(input_ids, 5.5, beam_scorer, logits_processor=logits_processor, **model_kwargs) + >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) """ diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2ebab93b8b83..b0e81bd8cbc1 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -118,11 +118,6 @@ def __init__(self, *args, **kwargs): requires_pytorch(self) -class GroupBeamScorer: - def __init__(self, *args, **kwargs): - requires_pytorch(self) - - class LogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) From a221ef0fb6a109b0be0d42993d0545056618f20d Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 12:30:40 +0000 Subject: [PATCH 26/31] add to docs --- docs/source/internal/generation_utils.rst | 6 ++++++ src/transformers/__init__.py | 2 ++ src/transformers/generation_utils.py | 1 + 3 files changed, 9 insertions(+) diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst index 9496827a5e16..eaba31060021 100644 --- a/docs/source/internal/generation_utils.rst +++ b/docs/source/internal/generation_utils.rst @@ -40,6 +40,12 @@ generation. .. autoclass:: transformers.NoBadWordsLogitsProcessor :members: __call__ +.. autoclass:: transformers.PrefixConstrainedLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.HammingDiversityLogitsProcessor + :members: __call__ + BeamSearch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a92fb488125b..791ca8d1a530 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -296,12 +296,14 @@ ) from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessor, LogitsProcessorList, LogitsWarper, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 3dcc3301eead..5a9546bf1107 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1322,6 +1322,7 @@ def group_beam_search( ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, + ... HammingDiversityLogitsProcessor, ... BeamSearchScorer, ... ) >>> import torch From 7d8d469793a6453cfaec3a5149d7fcb48166d49f Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 12:34:04 +0000 Subject: [PATCH 27/31] fix-copies --- src/transformers/utils/dummy_pt_objects.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b0e81bd8cbc1..bfb01f24e632 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -118,6 +118,11 @@ def __init__(self, *args, **kwargs): requires_pytorch(self) +class HammingDiversityLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class LogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -148,6 +153,11 @@ def __init__(self, *args, **kwargs): requires_pytorch(self) +class PrefixConstrainedLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class RepetitionPenaltyLogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) From c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Mon, 7 Dec 2020 19:23:39 +0530 Subject: [PATCH 28/31] bug fix --- src/transformers/generation_beam_search.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 55b7dbc143da..36ddc9a8bf02 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -282,16 +282,20 @@ def finalize( eos_token_id: Optional[int] = None, ) -> torch.LongTensor: batch_size = len(self._beam_hyps) + final_beam_scores = final_beam_scores.view((batch_size, self.num_beams)) # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: continue + batch_beam_scores = final_beam_scores[batch_idx, :] + _, beam_ids = torch.sort(batch_beam_scores, descending=True) + # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() + for beam_id in beam_ids: + batch_beam_idx = batch_idx * self.num_beams + beam_id.item() + final_score = batch_beam_scores[beam_id.item()].item() final_tokens = input_ids[batch_beam_idx] beam_hyp.add(final_tokens, final_score) From da77f2977e5c7e3b0a5bfc1604c6ebc6653ff6d9 Mon Sep 17 00:00:00 2001 From: Ayush Jain Date: Mon, 7 Dec 2020 19:38:01 +0530 Subject: [PATCH 29/31] Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. --- src/transformers/generation_beam_search.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 36ddc9a8bf02..55b7dbc143da 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -282,20 +282,16 @@ def finalize( eos_token_id: Optional[int] = None, ) -> torch.LongTensor: batch_size = len(self._beam_hyps) - final_beam_scores = final_beam_scores.view((batch_size, self.num_beams)) # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: continue - batch_beam_scores = final_beam_scores[batch_idx, :] - _, beam_ids = torch.sort(batch_beam_scores, descending=True) - # need to add best num_beams hypotheses to generated hyps - for beam_id in beam_ids: - batch_beam_idx = batch_idx * self.num_beams + beam_id.item() - final_score = batch_beam_scores[beam_id.item()].item() + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] beam_hyp.add(final_tokens, final_score) From 757bd3589acbe0399f5440d84c578cdd6f17d1b4 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 14:15:43 +0000 Subject: [PATCH 30/31] improve comment --- src/transformers/generation_beam_search.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 55b7dbc143da..38f144af915b 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -288,7 +288,8 @@ def finalize( if self._done[batch_idx]: continue - # need to add best num_beams hypotheses to generated hyps + # all open beam hypotheses are added to the beam hypothesis + # beam hypothesis class automatically keeps the best beams for beam_id in range(self.num_beams): batch_beam_idx = batch_idx * self.num_beams + beam_id final_score = final_beam_scores[batch_beam_idx].item() From 81b953382a9b21160480faa68697a9bdee07f75f Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 7 Dec 2020 16:51:52 +0000 Subject: [PATCH 31/31] implement sylvains feedback --- src/transformers/generation_beam_search.py | 3 ++- src/transformers/generation_logits_process.py | 6 +++--- src/transformers/generation_utils.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 38f144af915b..b04c93d5673f 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -188,7 +188,8 @@ def __init__( if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): raise ValueError( - f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` " + f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." ) @property diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index ea3e8d2e5c67..8ffab62386cb 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -419,13 +419,13 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): `__ for more details. Args: - diversity_penalty (:obj:`float`, `optional`): + diversity_penalty (:obj:`float`): This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled. - num_beams (:obj:`int`, `optional`): + num_beams (:obj:`int`): Number of beams used for group beam search. See `this paper `__ for more details. - num_beam_groups (:obj:`int`, `optional`): + num_beam_groups (:obj:`int`): Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of beams. See `this paper `__ for more details. """ diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 6ddb3ad4e5cf..91cc97c95c1e 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1380,7 +1380,7 @@ def group_beam_search( assert ( num_beams * batch_size == batch_beam_size - ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in