Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def forward(

causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
causal_mask = attention_mask[:, : key_states.shape[-2]]

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
Expand Down
25 changes: 20 additions & 5 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def _pad_to_max_length(


class WhisperGenerationMixin(GenerationMixin):
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
def _extract_token_timestamps(
self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
):
"""
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
Expand All @@ -200,11 +202,18 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
# since the beam search strategy chooses the most probable sequences at the end of the search.
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids

# beam search takes `decoder_input_ids` into account in the `beam_indices` length
# but forgot to shift the beam_indices by the number of `decoder_input_ids`
beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
# we actually shif the beam indices here
beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]

weights = weights[:, :, :weight_length]

# If beam index is still -1, it means that the associated token id is EOS
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
beam_indices = generate_outputs.beam_indices[:, :weight_length]
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)

# Select the cross attention from the right beam for each output sequences
Expand All @@ -218,8 +227,10 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec

# make sure timestamps are as long as weights
input_length = weight_length or cross_attentions[0].shape[2]
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
batch_size = timestamps.shape[0]
batch_size = generate_outputs.sequences.shape[0]
timestamps = torch.zeros(
(batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
)

if num_frames is not None:
# two cases:
Expand All @@ -239,6 +250,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
else:
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
num_frames = np.repeat(num_frames, repeat_time)

if num_frames is None or isinstance(num_frames, int):
Expand Down Expand Up @@ -948,7 +960,10 @@ def _postprocess_outputs(
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
seek_outputs,
generation_config.alignment_heads,
num_frames=num_frames,
num_input_ids=decoder_input_ids.shape[-1],
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def forward(

causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
causal_mask = attention_mask[:, : key_states.shape[-2]]

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
Expand Down
22 changes: 11 additions & 11 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,14 +1960,14 @@ def test_large_generation_multilingual(self):
input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Mein sechster Sohn scheint, wenigstens auf den ersten Blick,"
EXPECTED_TRANSCRIPT = " Denken Sie, soeben walten meine Gedanken bei Ihnen in Adela"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " My sixth son seems, at least at first glance, the most deeply-minded"
EXPECTED_TRANSCRIPT = " Think, my thoughts were just rolling with you in Adelaide, and I"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
Expand Down Expand Up @@ -2282,7 +2282,7 @@ def test_tiny_token_timestamp_generation(self):
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)

self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
Expand All @@ -2293,7 +2293,7 @@ def test_tiny_token_timestamp_generation(self):
])
# fmt: on

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))

@slow
def test_large_token_timestamp_generation(self):
Expand All @@ -2312,7 +2312,7 @@ def test_large_token_timestamp_generation(self):
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)

self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
Expand All @@ -2323,7 +2323,7 @@ def test_large_token_timestamp_generation(self):
])
# fmt: on

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))

@slow
def test_tiny_token_timestamp_batch_generation(self):
Expand All @@ -2350,9 +2350,9 @@ def test_tiny_token_timestamp_batch_generation(self):
)

# task id and lang id prompts should not have timestamp tokens
self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])
self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])

self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)

@slow
def test_tiny_token_timestamp_generation_longform(self):
Expand Down Expand Up @@ -2843,7 +2843,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):

torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)

assert decoded == EXPECTED_TEXT

Expand All @@ -2858,7 +2858,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):

torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)

assert decoded == EXPECTED_TEXT1

Expand Down Expand Up @@ -3158,7 +3158,7 @@ def test_whisper_shortform_multi_batch_hard_prev_cond(self):
}

result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)

for i in range(num_samples):
if isinstance(EXPECTED_TEXT[i], str):
Expand Down