diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py
index 792b2a17f5d6..71a459c06852 100644
--- a/src/transformers/generation/beam_search.py
+++ b/src/transformers/generation/beam_search.py
@@ -43,6 +43,10 @@
             The id of the *padding* token.
         eos_token_id (`Union[int, List[int]]`, *optional*):
             The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        beam_indices (`torch.LongTensor`, *optional*):
+            Beam indices indicating to which beam hypothesis each token corresponds.
+        group_index (`int`, *optional*):
+            The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
 
     Return:
         `UserDict`: A dictionary composed of the fields as defined above:
@@ -175,16 +179,22 @@ def __init__(
         self.group_size = self.num_beams // self.num_beam_groups
 
         self._is_init = False
+        # self._beam_hyps[i * self.num_beam_groups + j] holds the beam hypotheses of the j-th group in the i-th mini-batch.
+        # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
         self._beam_hyps = [
             BeamHypotheses(
-                num_beams=self.num_beams,
+                num_beams=self.group_size,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
                 max_length=max_length,
             )
-            for _ in range(batch_size)
+            for _ in range(batch_size * self.num_beam_groups)
         ]
-        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+        # self._done[i * self.num_beam_groups + j] indicates whether the generation of the beam_hyps of the j-th group
+        # in the i-th mini-batch is complete.
+        self._done = torch.tensor(
+            [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
+        )
 
         if not isinstance(num_beams, int) or num_beams <= 1:
             raise ValueError(
@@ -211,9 +221,11 @@ def process(
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
+        group_index: Optional[int] = 0,
     ) -> Dict[str, torch.Tensor]:
         cur_len = input_ids.shape[-1] + 1  # add up to the length which the next_scores is calculated on
-        batch_size = len(self._beam_hyps)
+        batch_size = len(self._beam_hyps) // self.num_beam_groups
+
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
             if self.num_beam_groups > 1:
                 raise ValueError(
@@ -234,9 +246,10 @@
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
 
-        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-            if self._done[batch_idx]:
-                if self.num_beams < len(beam_hyp):
+        for batch_idx in range(batch_size):
+            batch_group_idx = batch_idx * self.num_beam_groups + group_index
+            if self._done[batch_group_idx]:
+                if self.num_beams < len(self._beam_hyps[batch_group_idx]):
                     raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                 if eos_token_id is None or pad_token_id is None:
                     raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
@@ -264,7 +277,7 @@
                 else:
                     beam_index = None
 
-                beam_hyp.add(
+                self._beam_hyps[batch_group_idx].add(
                     input_ids[batch_beam_idx].clone(),
                     next_score.item(),
                     beam_indices=beam_index,
@@ -287,7 +300,7 @@
             )
 
             # Check if we are done so that we can save a pad step if all(done)
-            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+            self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
                 next_scores[batch_idx].max().item(), cur_len
             )
 
@@ -310,20 +323,20 @@
         eos_token_id: Optional[Union[int, List[int]]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.LongTensor]:
-        batch_size = len(self._beam_hyps)
+        batch_size = len(self._beam_hyps) // self.num_beam_groups
 
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
 
         # finalize all open beam hypotheses and add to generated hypotheses
-        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-            if self._done[batch_idx]:
+        for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_group_idx]:
                 continue
 
             # all open beam hypotheses are added to the beam hypothesis
             # beam hypothesis class automatically keeps the best beams
-            for beam_id in range(self.num_beams):
-                batch_beam_idx = batch_idx * self.num_beams + beam_id
+            for index_per_group in range(self.group_size):
+                batch_beam_idx = batch_group_idx * self.group_size + index_per_group
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
@@ -336,8 +349,10 @@ def finalize(
         best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
 
         # retrieve best hypotheses
-        for i, beam_hyp in enumerate(self._beam_hyps):
-            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+        for i in range(batch_size):
+            beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
+            candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
+            sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
             for j in range(self.num_beam_hyps_to_keep):
                 best_hyp_tuple = sorted_hyps.pop()
                 best_score = best_hyp_tuple[0]
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e5da7a143b4c..b50689ba3f96 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3522,10 +3522,10 @@ def group_beam_search(
             else self.generation_config.return_dict_in_generate
         )
 
-        batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
         num_beam_groups = beam_scorer.num_beam_groups
         num_sub_beams = num_beams // num_beam_groups
+        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
         device = input_ids.device
 
         batch_beam_size, cur_len = input_ids.shape
@@ -3648,6 +3648,7 @@
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 beam_indices=process_beam_indices,
+                group_index=beam_group_idx,
             )
             beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
             beam_next_tokens = beam_outputs["next_beam_tokens"]
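
Reviewer note (not part of the diff): the heart of this patch is the flattened, group-major layout of `_beam_hyps`, with one `BeamHypotheses` container per (batch item, beam group) pair. Below is a minimal sketch of the index arithmetic that `process` and `finalize` now rely on; the helper names `batch_group_index` and `hyps_for_batch` are invented for this illustration and do not exist in the codebase.

# Illustrative sketch only -- not part of the patch. It models the layout
# BeamSearchScorer uses after this change: self._beam_hyps holds
# batch_size * num_beam_groups entries, and entry i * num_beam_groups + j
# belongs to group j of batch item i.
batch_size = 2
num_beams = 6
num_beam_groups = 3
group_size = num_beams // num_beam_groups  # beams per group, as in __init__

# Stand-in for self._beam_hyps, tagged with (batch item, group) for clarity.
beam_hyps = [(i, j) for i in range(batch_size) for j in range(num_beam_groups)]

def batch_group_index(batch_idx: int, group_index: int) -> int:
    # Mirrors `batch_group_idx = batch_idx * self.num_beam_groups + group_index`
    # computed in process() for the group currently being advanced.
    return batch_idx * num_beam_groups + group_index

def hyps_for_batch(batch_idx: int) -> list:
    # Mirrors the slice finalize() takes to gather every group of one batch
    # item before ranking all candidate beams together.
    return beam_hyps[batch_idx * num_beam_groups : (batch_idx + 1) * num_beam_groups]

assert beam_hyps[batch_group_index(1, 2)] == (1, 2)
assert hyps_for_batch(0) == [(0, 0), (0, 1), (0, 2)]
# The derived sizes line up with the patch's bookkeeping: this is how
# process()/finalize()/group_beam_search() recover batch_size.
assert len(beam_hyps) == batch_size * num_beam_groups
assert len(beam_hyps) // num_beam_groups == batch_size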