diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index a5ac1f836385..6422baac5feb 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -291,7 +291,7 @@ def forward(

         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, : key_states.shape[-2]]

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
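Note (editor's sketch, not part of the diff): the attention hunk above, and the identical one in modeling_whisper.py further down, change the mask slice from 4-D to 2-D indexing. A plausible reading is that this code path now receives a 2-D padding mask of shape (batch, kv_length) rather than a 4-D additive mask of shape (batch, 1, q_length, kv_length); the toy shapes below only illustrate how the two slices differ and are not taken from the PR.

import torch

batch, heads, q_len, kv_len, head_dim = 2, 4, 3, 7, 16
key_states = torch.randn(batch, heads, kv_len, head_dim)

# Old indexing: a 4-D additive mask (batch, 1, q_len, kv_len); the key length sits on the last dim.
mask_4d = torch.zeros(batch, 1, q_len, kv_len)
old_slice = mask_4d[:, :, :, : key_states.shape[-2]]  # shape (2, 1, 3, 7)

# New indexing: a 2-D padding mask (batch, kv_len); the key length sits on dim 1.
mask_2d = torch.ones(batch, kv_len, dtype=torch.long)
new_slice = mask_2d[:, : key_states.shape[-2]]  # shape (2, 7)

print(old_slice.shape, new_slice.shape)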
diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 7a4e9487288e..32e54e0a121d 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -173,7 +173,9 @@ def _pad_to_max_length(


 class WhisperGenerationMixin(GenerationMixin):
-    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
+    def _extract_token_timestamps(
+        self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
+    ):
         """
         Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW)
         to map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
@@ -200,11 +202,18 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
             # since the beam search strategy chooses the most probable sequences at the end of the search.
             # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
             weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
+            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
+
+            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
+            # but does not shift the beam_indices by the number of `decoder_input_ids`
+            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
+            # we actually shift the beam indices here
+            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
+
             weights = weights[:, :, :weight_length]

             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = generate_outputs.beam_indices[:, :weight_length]
             beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)

             # Select the cross attention from the right beam for each output sequences
@@ -218,8 +227,10 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec

         # make sure timestamps are as long as weights
         input_length = weight_length or cross_attentions[0].shape[2]
-        timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
-        batch_size = timestamps.shape[0]
+        batch_size = generate_outputs.sequences.shape[0]
+        timestamps = torch.zeros(
+            (batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
+        )

         if num_frames is not None:
             # two cases:
@@ -239,6 +250,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
         else:
             # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
             repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
+            num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
             num_frames = np.repeat(num_frames, repeat_time)

         if num_frames is None or isinstance(num_frames, int):
@@ -948,7 +960,10 @@ def _postprocess_outputs(
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
             num_frames = getattr(generation_config, "num_frames", None)
             seek_outputs["token_timestamps"] = self._extract_token_timestamps(
-                seek_outputs, generation_config.alignment_heads, num_frames=num_frames
+                seek_outputs,
+                generation_config.alignment_heads,
+                num_frames=num_frames,
+                num_input_ids=decoder_input_ids.shape[-1],
             )

             seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
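Note (editor's sketch, not part of the diff): the `_extract_token_timestamps` hunks above pad `beam_indices` on the left by the number of decoder prompt tokens (`num_input_ids`) so that each column lines up with the corresponding cross-attention weight. A self-contained toy run of that shift, with made-up tensor contents:

import torch

# Toy beam_indices as produced by beam search: shape (batch, seq_len). The sequence length
# already counts the decoder prompt tokens, but the stored indices start at column 0 and
# the unused tail is filled with -1.
beam_indices = torch.tensor(
    [[0, 1, 1, -1, -1, -1],
     [2, 0, -1, -1, -1, -1]]
)
num_input_ids = 2  # number of decoder prompt tokens, e.g. task and language ids

weight_length = (beam_indices != -1).sum(-1).max()  # longest run of real indices: 3
weight_length = weight_length + num_input_ids       # account for the prompt: 5

# Shift the indices right by num_input_ids; prompt positions default to beam 0,
# and any remaining -1 (finished sequences) is mapped to 0 before index_select.
shifted = torch.zeros_like(beam_indices[:, :weight_length])
shifted[:, num_input_ids:] = beam_indices[:, : weight_length - num_input_ids]
shifted = shifted.masked_fill(shifted == -1, 0)

print(shifted)
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 2, 0, 0]])

With the indices aligned this way, the subsequent index_select over the cross-attention weights picks the right beam for every decoding step, including the prompt positions.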
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b10fc258c8ef..4a38ad0a5e77 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -422,7 +422,7 @@ def forward(

         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, : key_states.shape[-2]]

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index b4e71ca72e56..f36d02175bcd 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1960,14 +1960,14 @@ def test_large_generation_multilingual(self):
             input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        EXPECTED_TRANSCRIPT = " Mein sechster Sohn scheint, wenigstens auf den ersten Blick,"
+        EXPECTED_TRANSCRIPT = " Denken Sie, soeben walten meine Gedanken bei Ihnen in Adela"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

         generated_ids = model.generate(
             input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        EXPECTED_TRANSCRIPT = " My sixth son seems, at least at first glance, the most deeply-minded"
+        EXPECTED_TRANSCRIPT = " Think, my thoughts were just rolling with you in Adelaide, and I"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

     @slow
@@ -2282,7 +2282,7 @@ def test_tiny_token_timestamp_generation(self):
             input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
         )

-        self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
+        self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)

         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
@@ -2293,7 +2293,7 @@ def test_tiny_token_timestamp_generation(self):
         ])
         # fmt: on

-        self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
+        self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))

     @slow
     def test_large_token_timestamp_generation(self):
@@ -2312,7 +2312,7 @@ def test_large_token_timestamp_generation(self):
             **input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
         )

-        self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
+        self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)

         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
@@ -2323,7 +2323,7 @@ def test_large_token_timestamp_generation(self):
         ])
         # fmt: on

-        self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
+        self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))

     @slow
     def test_tiny_token_timestamp_batch_generation(self):
@@ -2350,9 +2350,9 @@ def test_tiny_token_timestamp_batch_generation(self):
         )

         # task id and lang id prompts should not have timestamp tokens
-        self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])
+        self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])

-        self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
+        self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)

     @slow
     def test_tiny_token_timestamp_generation_longform(self):
@@ -2843,7 +2843,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)

-        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        decoded = processor.batch_decode(result, skip_special_tokens=True)

         assert decoded == EXPECTED_TEXT
@@ -2858,7 +2858,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)

-        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        decoded = processor.batch_decode(result, skip_special_tokens=True)

         assert decoded == EXPECTED_TEXT1

@@ -3158,7 +3158,7 @@ def test_whisper_shortform_multi_batch_hard_prev_cond(self):
         }

         result = model.generate(**inputs, **gen_kwargs)
-        decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        decoded_all = processor.batch_decode(result, skip_special_tokens=True)

         for i in range(num_samples):
             if isinstance(EXPECTED_TEXT[i], str):
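Note (editor's sketch, not part of the diff): a rough end-to-end illustration of the path the timestamp tests above exercise. The checkpoint name, the silent dummy audio, and the beam-search settings are placeholder assumptions rather than values from the PR; with return_token_timestamps=True the generate output can be indexed like a dict, which is what the updated assertions rely on.

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 30 seconds of silence stand in for a real 16 kHz recording.
audio = torch.zeros(30 * 16_000).numpy()
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

generate_outputs = model.generate(
    inputs.input_features,
    num_beams=2,  # beam search is the case the beam_indices shift above fixes
    return_timestamps=True,
    return_token_timestamps=True,
)

# Dict-style access, as used in the updated tests.
print(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
print(processor.batch_decode(generate_outputs["sequences"], skip_special_tokens=True))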