From 3426db599397159cd13173bf040583c248b8fdc4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 May 2025 13:00:17 +0000 Subject: [PATCH 1/6] tmp --- .../models/whisper/generation_whisper.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index da1d83b2a8b0..9e273149161e 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -244,17 +244,27 @@ def _extract_token_timestamps( weight_length = None if "beam_indices" in generate_outputs: - # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths - # since the beam search strategy chooses the most probable sequences at the end of the search. - # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length + # If beam search has been used, the output sequences may have been generated for more timesteps than their + # `sequence_lengths` since the beam search strategy chooses the most probable sequences at the end of the + # search. In that case, the `cross_attentions` weights are too long and we have to make sure that they have + # the right `output_length` weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() - weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids - # beam search takes `decoder_input_ids` into account in the `beam_indices` length - # but forgot to shift the beam_indices by the number of `decoder_input_ids` - beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length]) - # we actually shift the beam indices here - beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids] + # The first forward pass (prefill) may have processed more than one token and, therefore, contain + # cross-attention weights for several tokens. + # Let's unfold the first `beam_indices` accordingly. + if num_input_ids is not None: + # -1 because the original length would be correct if `num_input_ids` is 1 + weight_length += (num_input_ids - 1) + prepend_beam_indices = torch.ones( + generate_outputs.beam_indices.shape[0], + num_input_ids - 1, + device=generate_outputs.beam_indices.device, + dtype=torch.long, + ) * (generate_outputs.beam_indices[:, 0]) + beam_indices = torch.cat([prepend_beam_indices, generate_outputs.beam_indices], dim=-1) + else: + beam_indices = generate_outputs.beam_indices weights = weights[:, :, :weight_length] From fcd89d45b09800ed27002e10ee8c8c778fef9e89 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 May 2025 13:36:40 +0000 Subject: [PATCH 2/6] fix test_tiny_token_timestamp_batch_generation --- .../models/whisper/generation_whisper.py | 17 ++++++++++------- tests/models/whisper/test_modeling_whisper.py | 4 +--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 9e273149161e..b69ef60f2c33 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -255,13 +255,16 @@ def _extract_token_timestamps( # Let's unfold the first `beam_indices` accordingly. 
if num_input_ids is not None: # -1 because the original length would be correct if `num_input_ids` is 1 - weight_length += (num_input_ids - 1) - prepend_beam_indices = torch.ones( - generate_outputs.beam_indices.shape[0], - num_input_ids - 1, - device=generate_outputs.beam_indices.device, - dtype=torch.long, - ) * (generate_outputs.beam_indices[:, 0]) + weight_length += num_input_ids - 1 + prepend_beam_indices = ( + torch.ones( + generate_outputs.beam_indices.shape[0], + num_input_ids - 1, + device=generate_outputs.beam_indices.device, + dtype=torch.long, + ) + * (generate_outputs.beam_indices[:, 0:1]) + ) beam_indices = torch.cat([prepend_beam_indices, generate_outputs.beam_indices], dim=-1) else: beam_indices = generate_outputs.beam_indices diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 0446fb2052d4..f97a459b239b 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2174,9 +2174,7 @@ def test_tiny_token_timestamp_batch_generation(self): num_return_sequences=num_return_sequences, ) - # task id and lang id prompts should not have timestamp tokens - self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1]) - + self.assertEqual(generate_outputs["sequences"].shape[-1], generate_outputs["token_timestamps"].shape[-1]) self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples) @slow From 023ad36c66a813966159545e7fccbbdd70094e4c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 May 2025 13:52:05 +0000 Subject: [PATCH 3/6] better comments --- .../models/whisper/generation_whisper.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index b69ef60f2c33..21f76a9bb098 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -244,19 +244,19 @@ def _extract_token_timestamps( weight_length = None if "beam_indices" in generate_outputs: - # If beam search has been used, the output sequences may have been generated for more timesteps than their - # `sequence_lengths` since the beam search strategy chooses the most probable sequences at the end of the - # search. In that case, the `cross_attentions` weights are too long and we have to make sure that they have + # If beam search has been used, the sequence length of the outputs may not be the real sequence length: + # beam search may end up returning a sequence that finished a few steps earlier while decoding. + # In that case, the `cross_attentions` weights are too long and we have to make sure that they have # the right `output_length` - weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() + weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() # real sequence length # The first forward pass (prefill) may have processed more than one token and, therefore, contain # cross-attention weights for several tokens. - # Let's unfold the first `beam_indices` accordingly. - if num_input_ids is not None: - # -1 because the original length would be correct if `num_input_ids` is 1 + # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights. 
+ if num_input_ids is not None and num_input_ids > 1: + # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1 weight_length += num_input_ids - 1 - prepend_beam_indices = ( + beam_indices_first_step_unrolled = ( torch.ones( generate_outputs.beam_indices.shape[0], num_input_ids - 1, @@ -265,21 +265,22 @@ def _extract_token_timestamps( ) * (generate_outputs.beam_indices[:, 0:1]) ) - beam_indices = torch.cat([prepend_beam_indices, generate_outputs.beam_indices], dim=-1) + unrolled_beam_indices = torch.cat( + [beam_indices_first_step_unrolled, generate_outputs.beam_indices], dim=-1 + ) else: - beam_indices = generate_outputs.beam_indices - - weights = weights[:, :, :weight_length] + unrolled_beam_indices = generate_outputs.beam_indices # If beam index is still -1, it means that the associated token id is EOS # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1. - beam_indices = beam_indices.masked_fill(beam_indices == -1, 0) + unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0) - # Select the cross attention from the right beam for each output sequences + # Select the cross attention from the right beam for each output sequences, up to the real sequence + # length (`weight_length`) weights = torch.stack( [ - torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i]) - for i in range(beam_indices.shape[1]) + torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i]) + for i in range(unrolled_beam_indices.shape[1]) ], dim=2, ) From 4bfc7fd05e08cd1931afed9b33d8783bbb584313 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 May 2025 14:15:50 +0000 Subject: [PATCH 4/6] test --- tests/models/whisper/test_modeling_whisper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f97a459b239b..4133da5c7f98 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2174,7 +2174,8 @@ def test_tiny_token_timestamp_batch_generation(self): num_return_sequences=num_return_sequences, ) - self.assertEqual(generate_outputs["sequences"].shape[-1], generate_outputs["token_timestamps"].shape[-1]) + # task id and lang id prompts should not have timestamp tokens + self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1]) self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples) @slow From a4912b779e0727834b2700e87f2552e4e2ba0ba6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 May 2025 14:30:55 +0000 Subject: [PATCH 5/6] comments --- .../models/whisper/generation_whisper.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 21f76a9bb098..ce8e91436e28 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -231,13 +231,13 @@ def _extract_token_timestamps( tensor containing the timestamps in seconds for each predicted token """ # Create a list with `decoder_layers` elements, each a tensor of shape - # (batch size, attention_heads, output length, input length). + # (batch size * num_beams, attention_heads, output length, input length). 
cross_attentions = [] for i in range(self.config.decoder_layers): cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2)) # Select specific cross-attention layers and heads. This is a tensor - # of shape (batch size, num selected, output length, input length). + # of shape (batch size * num_beams, num selected, output length, input length). weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) weights = weights.permute([1, 0, 2, 3]) @@ -248,7 +248,10 @@ def _extract_token_timestamps( # beam search may end up returning a sequence that finished a few steps earlier while decoding. # In that case, the `cross_attentions` weights are too long and we have to make sure that they have # the right `output_length` - weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() # real sequence length + + # get the real sequence length of the longest sequence, crop the beam_indices to the real length + weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() + beam_indices = generate_outputs.beam_indices[:, :weight_length] # The first forward pass (prefill) may have processed more than one token and, therefore, contain # cross-attention weights for several tokens. @@ -257,25 +260,18 @@ def _extract_token_timestamps( # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1 weight_length += num_input_ids - 1 beam_indices_first_step_unrolled = ( - torch.ones( - generate_outputs.beam_indices.shape[0], - num_input_ids - 1, - device=generate_outputs.beam_indices.device, - dtype=torch.long, - ) - * (generate_outputs.beam_indices[:, 0:1]) - ) - unrolled_beam_indices = torch.cat( - [beam_indices_first_step_unrolled, generate_outputs.beam_indices], dim=-1 + torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long) + * (beam_indices[:, 0:1]) ) + unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1) else: - unrolled_beam_indices = generate_outputs.beam_indices + unrolled_beam_indices = beam_indices # If beam index is still -1, it means that the associated token id is EOS # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1. unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0) - # Select the cross attention from the right beam for each output sequences, up to the real sequence + # Select the cross attention from the right beam for each output sequence, up to the real sequence # length (`weight_length`) weights = torch.stack( [ From cc32d45147e115e9585e2e249d4c160534f93b2a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 May 2025 10:53:15 +0100 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> --- src/transformers/models/whisper/generation_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index ce8e91436e28..8fadecc240fc 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -231,20 +231,20 @@ def _extract_token_timestamps( tensor containing the timestamps in seconds for each predicted token """ # Create a list with `decoder_layers` elements, each a tensor of shape - # (batch size * num_beams, attention_heads, output length, input length). 
+ # (batch size * num beams, attention_heads, output length, input length). cross_attentions = [] for i in range(self.config.decoder_layers): cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2)) # Select specific cross-attention layers and heads. This is a tensor - # of shape (batch size * num_beams, num selected, output length, input length). + # of shape (batch size * num beams, num selected heads, output length, input length). weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) weights = weights.permute([1, 0, 2, 3]) weight_length = None if "beam_indices" in generate_outputs: - # If beam search has been used, the sequence length of the outputs may not be the real sequence length: + # If beam search was used, the sequence length of the outputs may not be the real sequence length: # beam search may end up returning a sequence that finished a few steps earlier while decoding. # In that case, the `cross_attentions` weights are too long and we have to make sure that they have # the right `output_length`
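
The sketch below is a standalone toy illustration of the beam-index handling that the patch series converges on: crop `beam_indices` to the real sequence length, repeat the first index so the prefill tokens are covered, replace `-1` padding with `0`, then gather the per-beam cross-attention weights. The helper name, tensor shapes, and values are invented for illustration and are not part of the transformers API; it is a minimal sketch of the idea, not the library implementation.

# Toy sketch of the beam-index unrolling and weight gathering described above.
# Shapes, values, and the helper name are illustrative only.
import torch

def gather_cross_attention_per_beam(weights, beam_indices, num_input_ids):
    """Crop, unroll, and apply `beam_indices` to cross-attention `weights`.

    weights:       (batch * num beams, num heads, output length, input length)
    beam_indices:  (batch * num return sequences, decoded length), padded with -1
                   for steps after a sequence finished early.
    num_input_ids: number of prompt tokens processed in the prefill step.
    """
    # Real sequence length of the longest sequence; crop the padded indices.
    weight_length = (beam_indices != -1).sum(-1).max()
    beam_indices = beam_indices[:, :weight_length]

    # The prefill forward pass produced cross-attention weights for every prompt
    # token, but beam search records only one index for that step, so repeat the
    # first beam index until there is one index per prefill token.
    if num_input_ids is not None and num_input_ids > 1:
        weight_length += num_input_ids - 1
        first_step = beam_indices[:, 0:1].expand(-1, num_input_ids - 1)
        beam_indices = torch.cat([first_step, beam_indices], dim=-1)

    # Remaining -1 entries mark positions after EOS; index_select rejects
    # negative indices, so replace them with 0.
    beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)

    # For every output position, pick the weights of the beam that produced it.
    gathered = torch.stack(
        [
            torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
            for i in range(beam_indices.shape[1])
        ],
        dim=2,
    )
    return gathered, weight_length

# Toy example: 2 sequences, 4 alignment heads, 7 output steps, 10 encoder frames,
# a 3-token prompt, and one sequence that finished two steps early.
weights = torch.randn(2, 4, 7, 10)
beam_indices = torch.tensor([[1, 0, 0, 0, 0], [0, 1, 1, -1, -1]])
gathered, length = gather_cross_attention_per_beam(weights, beam_indices, num_input_ids=3)
print(gathered.shape, length)  # torch.Size([2, 4, 7, 10]) tensor(7)

The toy call shows why the unrolling matters: with a 3-token prompt, the 5 recorded beam-search steps expand to 7 gathered positions, matching the 7 cross-attention timesteps produced by generation.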