46 changes: 28 additions & 18 deletions src/transformers/models/whisper/generation_whisper.py
@@ -231,42 +231,52 @@ def _extract_token_timestamps(
             tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
[Contributor Author] The shape comments were incorrect for the case with beam search.
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))

         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
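For orientation, a minimal runnable sketch of the shape bookkeeping above; all dimensions and the `alignment_heads` pairs are made up for the demo, not taken from the PR:

```python
import torch

batch_times_beams, heads, out_len, in_len = 6, 8, 10, 1500  # hypothetical sizes
decoder_layers = 4
# One tensor per decoder layer, as produced by the torch.cat loop above.
cross_attentions = [torch.rand(batch_times_beams, heads, out_len, in_len) for _ in range(decoder_layers)]

alignment_heads = [(1, 3), (2, 5)]  # (layer, head) pairs, chosen arbitrarily for the demo
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])
print(weights.shape)  # torch.Size([6, 2, 10, 1500]): (batch * beams, selected heads, output len, input len)
```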

-        weight_length = None

         if "beam_indices" in generate_outputs:
-            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
[Contributor Author] In this if block, I've rewritten the comments to better explain what's happening.
-            # since the beam search strategy chooses the most probable sequences at the end of the search.
-            # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
-            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
[Contributor Author] Root issue of #36093: weight_length is off by one. The comments in the new version explain why.

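To make the off-by-one concrete, a toy calculation with numbers assumed purely for illustration: the prefill forward pass already yields the first generated token, so the attention weights cover one position fewer than the old code assumed:

```python
# Toy numbers, assumed purely for illustration.
num_input_ids = 4        # decoder_input_ids processed in a single prefill forward pass
generated_length = 10    # real sequence length recovered from (beam_indices != -1).sum(-1).max()

# The prefill pass emits cross-attention rows for all 4 prompt positions at once, and that
# same pass also yields the first generated token, so the total number of rows is:
correct_length = generated_length + (num_input_ids - 1)  # 13, what the new code computes
buggy_length = generated_length + num_input_ids          # 14, the old off-by-one (#36093)
```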

-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
[Contributor Author, on lines -253 to -257] This is not correct after the beam search refactor (#35802): beam_indices was corrected to have the same output length as the other optional outputs (i.e. the length of the generated tokens).
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`

-            weights = weights[:, :, :weight_length]
[Contributor Author] Redundant: we rebuild `weights` below with sequence length = range(unrolled_beam_indices.shape[1]) (= weight_length).

+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
+            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
+
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices

             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
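A toy walk-through of the cropping, unrolling, and masking above; the beam indices and num_input_ids are invented for the demo, not taken from the PR:

```python
import torch

# Toy beam_indices for two finished sequences; -1 marks steps after a sequence hit EOS.
beam_indices = torch.tensor([[0, 1, 1, -1], [1, 1, 0, 0]])
num_input_ids = 3  # pretend the prefill processed 3 decoder input ids in one pass

weight_length = (beam_indices != -1).sum(-1).max()  # 4: length of the longest real sequence
beam_indices = beam_indices[:, :weight_length]

# Repeat the first step's beam index for the extra prefill positions (num_input_ids - 1 = 2).
first_step_unrolled = (
    torch.ones(beam_indices.shape[0], num_input_ids - 1, dtype=torch.long) * beam_indices[:, 0:1]
)
unrolled = torch.cat([first_step_unrolled, beam_indices], dim=-1)
unrolled = unrolled.masked_fill(unrolled == -1, 0)
print(unrolled)
# tensor([[0, 0, 0, 1, 1, 0],
#         [1, 1, 1, 1, 0, 0]])
```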

-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
             )
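A small sketch of this per-step gather, with shapes assumed for the demo: for each output position i, index_select picks, for every returned sequence, the attention row of the beam that sequence was on at step i:

```python
import torch

# Assumed shapes: 2 returned sequences, 1 selected head, 6 output positions, 5 input frames.
weights = torch.arange(2 * 1 * 6 * 5, dtype=torch.float).reshape(2, 1, 6, 5)
unrolled_beam_indices = torch.tensor([[0, 0, 0, 1, 1, 0], [1, 1, 1, 1, 0, 0]])

weights = torch.stack(
    [
        torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
        for i in range(unrolled_beam_indices.shape[1])
    ],
    dim=2,
)
print(weights.shape)  # torch.Size([2, 1, 6, 5]): same layout, rows reshuffled per step
```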
1 change: 0 additions & 1 deletion tests/models/whisper/test_modeling_whisper.py
@@ -2176,7 +2176,6 @@ def test_tiny_token_timestamp_batch_generation(self):

         # task id and lang id prompts should not have timestamp tokens
         self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
-
         self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)

     @slow