Merged
47 changes: 31 additions & 16 deletions src/transformers/generation/beam_search.py
@@ -43,6 +43,10 @@
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+beam_indices (`torch.LongTensor`, *optional*):
+    Beam indices indicating to which beam hypothesis each token corresponds.
+group_index (`int`, *optional*):
+    The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].

Return:
`UserDict`: A dictionary composed of the fields as defined above:
@@ -175,16 +179,22 @@ def __init__(
self.group_size = self.num_beams // self.num_beam_groups

self._is_init = False
+# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
+# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
self._beam_hyps = [
BeamHypotheses(
-num_beams=self.num_beams,
+num_beams=self.group_size,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
-for _ in range(batch_size)
+for _ in range(batch_size * self.num_beam_groups)
]
Comment on lines +182 to 192 (Contributor Author):

This change is the most important, and the rest of the changes are for consistency.
-self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
+# in the i-th mini-batch is complete.
+self._done = torch.tensor(
+    [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
+)

if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
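To make the new bookkeeping concrete: the scorer now keeps one `BeamHypotheses` container per (batch item, group) pair, flattened so that slot `i * num_beam_groups + j` belongs to group `j` of batch item `i`, each with capacity `group_size = num_beams // num_beam_groups`. The following is a minimal standalone sketch of that layout, not code from the PR; all names and values here are illustrative.

```python
# Illustrative sketch of the flattened (batch, group) indexing in BeamSearchScorer.
batch_size, num_beams, num_beam_groups = 2, 6, 3
group_size = num_beams // num_beam_groups  # 2 beams per group

# Stand-ins for the BeamHypotheses containers, in i-major order.
beam_hyps = [
    f"hyps(batch={i}, group={j}, capacity={group_size})"
    for i in range(batch_size)
    for j in range(num_beam_groups)
]
done = [False] * (batch_size * num_beam_groups)
assert len(beam_hyps) == len(done) == batch_size * num_beam_groups

def container(i: int, j: int) -> str:
    # beam_hyps[i * num_beam_groups + j] is group j of batch item i
    return beam_hyps[i * num_beam_groups + j]

assert container(1, 2) == beam_hyps[-1]  # batch 1, group 2 is the last slot
print(container(0, 1))  # -> "hyps(batch=0, group=1, capacity=2)"
```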
@@ -211,9 +221,11 @@ def process(
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
+group_index: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
-batch_size = len(self._beam_hyps)
+batch_size = len(self._beam_hyps) // self.num_beam_groups

if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
raise ValueError(
@@ -234,9 +246,10 @@
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

-for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-if self._done[batch_idx]:
-if self.num_beams < len(beam_hyp):
+for batch_idx in range(batch_size):
+batch_group_idx = batch_idx * self.num_beam_groups + group_index
+if self._done[batch_group_idx]:
+if self.num_beams < len(self._beam_hyps[batch_group_idx]):
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
if eos_token_id is None or pad_token_id is None:
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
@@ -264,7 +277,7 @@
else:
beam_index = None

-beam_hyp.add(
+self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
@@ -287,7 +300,7 @@
)

# Check if we are done so that we can save a pad step if all(done)
-self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
)

@@ -310,20 +323,20 @@ def finalize(
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
-batch_size = len(self._beam_hyps)
+batch_size = len(self._beam_hyps) // self.num_beam_groups

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
-for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-if self._done[batch_idx]:
+for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
+if self._done[batch_group_idx]:
continue

# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
-for beam_id in range(self.num_beams):
-batch_beam_idx = batch_idx * self.num_beams + beam_id
+for index_per_group in range(self.group_size):
+batch_beam_idx = batch_group_idx * self.group_size + index_per_group
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
@@ -336,8 +349,10 @@
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

# retrieve best hypotheses
-for i, beam_hyp in enumerate(self._beam_hyps):
-sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+for i in range(batch_size):
+beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
+candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
+sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
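The `finalize` change above pools candidates from all of a batch item's groups before ranking, so the returned sequences can come from different groups. A hedged sketch of that pooling, with made-up scores; the `(score, tokens)` tuples here stand in for the real beam tuples:

```python
# Illustrative: pooling beams across groups before picking the overall best.
num_beam_hyps_to_keep = 2

# Two groups for one batch item, each holding (score, tokens) tuples.
beam_hyps_in_batch = [
    [(-1.2, "hyp A"), (-2.0, "hyp B")],  # group 0
    [(-0.8, "hyp C"), (-1.5, "hyp D")],  # group 1
]
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp]
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])  # ascending score

# pop() takes the highest-scoring remaining hypothesis.
best = [sorted_hyps.pop() for _ in range(num_beam_hyps_to_keep)]
print(best)  # [(-0.8, 'hyp C'), (-1.2, 'hyp A')] -- winners span both groups
```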
3 changes: 2 additions & 1 deletion src/transformers/generation/utils.py
@@ -3522,10 +3522,10 @@ def group_beam_search(
else self.generation_config.return_dict_in_generate
)

-batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
+batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
device = input_ids.device

batch_beam_size, cur_len = input_ids.shape
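Because `_beam_hyps` now holds one entry per (batch item, group) pair, the caller recovers the true batch size by dividing by `num_beam_groups`. Illustrative numbers only:

```python
# Illustrative: recovering batch_size from the flattened hypothesis list.
num_beam_groups = 3
len_beam_hyps = 6  # 2 batch items x 3 groups
batch_size = len_beam_hyps // num_beam_groups
assert batch_size == 2
```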
@@ -3648,6 +3648,7 @@
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
+group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]