[Whisper + beam search] fix usage of beam_indices (#38259)
@@ -231,42 +231,52 @@ def _extract_token_timestamps(
             tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))

         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])

-        weight_length = None
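As an illustration of the shape bookkeeping above, here is a minimal sketch with made-up dimensions (not taken from the PR) of how the per-step cross-attentions returned by `generate()` are concatenated along `dim=2` into one tensor per decoder layer:

```python
import torch

# Made-up dimensions for illustration only.
batch_times_beams, num_heads, encoder_len = 2 * 3, 8, 1500
decoder_layers, decoded_steps = 4, 5

# generate() returns one tuple per decoding step; each tuple holds one tensor per
# decoder layer of shape (batch size * num beams, heads, tokens this step, input length).
cross_attentions_per_step = [
    tuple(torch.rand(batch_times_beams, num_heads, 1, encoder_len) for _ in range(decoder_layers))
    for _ in range(decoded_steps)
]

# Concatenating along dim=2 merges the per-step token axes into `output length`.
layer_0 = torch.cat([step[0] for step in cross_attentions_per_step], dim=2)
print(layer_0.shape)  # torch.Size([6, 8, 5, 1500])
```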
| if "beam_indices" in generate_outputs: | ||
| # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this |
||
| # since the beam search strategy chooses the most probable sequences at the end of the search. | ||
| # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length | ||
| weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() | ||
| weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. root issue of #36093: |
||
|
|
||
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]

Contributor (author) comment on removed lines 253 to 257: this is not correct with the beam search refactor (#35802).
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`

-            weights = weights[:, :, :weight_length]

Contributor (author) comment on the removed line: redundant, `weights` is rebuilt below.

+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
+            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
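A minimal sketch with made-up values of how the cropping above works: `beam_indices` is padded with -1 for steps after a sequence has finished, so counting the non-(-1) entries per row and taking the maximum recovers the real length of the longest returned sequence:

```python
import torch

# Two returned sequences that finished at different steps; -1 marks padded positions.
beam_indices = torch.tensor(
    [
        [0, 1, 1, 2, -1, -1],
        [2, 2, 0, 0, 1, -1],
    ]
)

weight_length = (beam_indices != -1).sum(-1).max()
print(weight_length)                    # tensor(5)
print(beam_indices[:, :weight_length])  # cropped to the real maximum length
```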
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices
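A minimal sketch with made-up values of the unrolling step: the prefill forward pass emits cross-attention weights for every prompt token, while `beam_indices` only has one entry for that step, so the first beam index is repeated `num_input_ids - 1` times to line the two up again:

```python
import torch

num_input_ids = 3  # e.g. a 3-token prompt (illustrative value)
beam_indices = torch.tensor([[0, 1, 1], [2, 2, 0]])  # (num sequences, decoded steps)

first_step_unrolled = (
    torch.ones(beam_indices.shape[0], num_input_ids - 1, dtype=torch.long) * beam_indices[:, 0:1]
)
unrolled_beam_indices = torch.cat([first_step_unrolled, beam_indices], dim=-1)
print(unrolled_beam_indices)
# tensor([[0, 0, 0, 1, 1],
#         [2, 2, 2, 2, 0]])
```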
             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
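A minimal sketch with made-up values of the masking step: `torch.index_select` does not accept -1, so the remaining -1 entries (positions at or beyond EOS, per the comment above) are mapped to beam 0 before gathering:

```python
import torch

unrolled_beam_indices = torch.tensor([[0, 1, -1], [2, 0, 1]])
safe_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
print(safe_indices)
# tensor([[0, 1, 0],
#         [2, 0, 1]])
```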
-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
             )
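A minimal sketch with made-up shapes of the gather above: for every output position `i`, `index_select` picks, for each returned sequence, the cross-attention weights of the beam that produced that position, and the per-position slices are re-stacked along the output-length axis:

```python
import torch

# Made-up shapes for illustration only.
num_sequences, num_heads, output_len, encoder_len = 2, 4, 3, 10
weights = torch.rand(num_sequences, num_heads, output_len, encoder_len)
unrolled_beam_indices = torch.tensor([[0, 1, 1], [1, 0, 0]])

gathered = torch.stack(
    [
        torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
        for i in range(unrolled_beam_indices.shape[1])
    ],
    dim=2,
)
print(gathered.shape)  # torch.Size([2, 4, 3, 10])
```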
Contributor (author) comment on the updated shape comments: the shape comments were incorrect for the case with beam search.