From 1f0f005b421261a988808c167f3331d4f59d42d7 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Fri, 17 Jan 2025 11:48:47 +0100 Subject: [PATCH 1/7] handle long form generation --- .../models/whisper/tokenization_whisper.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 7983799ad8a7..25c78b56950f 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -908,7 +908,7 @@ def _convert_to_list(token_ids): return token_ids -def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision): +def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500): """ Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle the various options not allowed in other seq2seq models @@ -960,6 +960,12 @@ def new_chunk(): last_timestamp = None first_timestamp = timestamp_begin + # long form generation: we need to handle the case where the call to generate returns concatenated segments, + # with underlying multiple calls to generate + cur_max_timestamp = 0.0 + prev_segments_len = 0.0 + penultimate_timestamp = 0.0 + if "stride" in output: chunk_len, stride_left, stride_right = output["stride"] # Offset the timings to account for the other `model_outputs`. @@ -1022,7 +1028,24 @@ def new_chunk(): pass elif token >= timestamp_begin: # 3/ Timestamp token - time = (token - timestamp_begin) * time_precision + time_offset + + timestamp = float((token - timestamp_begin) * time_precision) + if timestamp < cur_max_timestamp: + # next segment has started + last_was_single_ending = i >= 2 and not ( + token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin + ) + if last_was_single_ending: + prev_segments_len += time_precision * segment_size + else: + cur_max_timestamp = penultimate_timestamp + prev_segments_len += penultimate_timestamp + + penultimate_timestamp = cur_max_timestamp + cur_max_timestamp = timestamp + + time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len + time = round(time, 2) if last_timestamp and token >= last_timestamp: # Whisper outputted a timestamp token, but it falls within From 559ed1372fa9d1a6c8d386dbe40702acc2f80479 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Fri, 17 Jan 2025 14:44:52 +0100 Subject: [PATCH 2/7] add warning --- src/transformers/pipelines/automatic_speech_recognition.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 66a9c49ea5f3..d37d0b9cbeee 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -303,6 +303,11 @@ def _sanitize_parameters( " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...," " ignore_warning=True)" ) + elif self.type == "seq2seq_whisper" and not ignore_warning: + logger.warning( + "Using `chunk_length_s` with Whisper models is not recommended and will result in unreliable results, as it uses it's own chunking mechanism " + "(cf. Whisper original paper, section 3.8. Long-form Transcription)." 
+ ) preprocess_params["chunk_length_s"] = chunk_length_s if stride_length_s is not None: preprocess_params["stride_length_s"] = stride_length_s From fbadefe555a5eee33827b7c14ce72ef86105aeb4 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 12 Mar 2025 15:09:54 +0100 Subject: [PATCH 3/7] correct incorrect in place token change --- .../models/whisper/generation_whisper.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1b4ecb831bf9..273a6640bd9e 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -135,6 +135,8 @@ def _pad_to_max_length( cut_off_length=None, return_token_timestamps=False, force_unique_generate_call=False, + skip_ending_double_timestamps=False, + timestamp_begin=None, ): max_total_length = 0 sequences = [] @@ -165,7 +167,17 @@ def _pad_to_max_length( for current_segment_list in current_segments: if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: - sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) + sequences_list = [] + for d in current_segment_list: + if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin: + # the segment finishes with two timestamp tokens + # we need to ignore the last timestamp token + # see https://github.com/huggingface/transformers/pull/34537 + sequences_list.append(d["tokens"][:-1]) + else: + sequences_list.append(d["tokens"]) + sequence = torch.cat(sequences_list, dim=-1) + if return_token_timestamps: token_timestamps = torch.cat( [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list], @@ -1785,14 +1797,6 @@ def _prepare_decoder_input_ids( # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] - for segments in active_segments: - for seg in segments: - if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin: - # the segment finishes with two timestamp tokens - # we need to ignore the last timestamp token - # see https://github.com/huggingface/transformers/pull/34537 - seg["tokens"] = seg["tokens"][:-1] - if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments": prev_ids = prompt_ids else: @@ -1809,6 +1813,8 @@ def _prepare_decoder_input_ids( padding=padding, bos_token_tensor=prev_ids, cut_off_length=cut_off_length, + skip_ending_double_timestamps=True, + timestamp_begin=timestamp_begin, ) decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) From 41dd387de963f21437d34917e50c5bbf19ecf1d1 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 12 Mar 2025 15:58:05 +0100 Subject: [PATCH 4/7] update test to catch edge case --- tests/models/whisper/test_modeling_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index ce30ea4eaeb0..695d8b5bca01 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2214,11 +2214,11 @@ def test_large_timestamp_generation(self): ).input_features input_features = input_features.to(torch_device) - generated_ids = 
model.generate(input_features, max_length=448, return_timestamps=True).to("cpu") + generated_ids = model.generate(input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True).to("cpu") # fmt: off EXPECTED_OUTPUT = torch.tensor([ - 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430 + [50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50431] ]) # fmt: on torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT) @@ -2261,7 +2261,7 @@ def test_large_timestamp_generation(self): }, { "text": (" and can discover"), - "timestamp": (28.68, 29.98), + "timestamp": (28.68, 30.0), }, ], } From 15071f4efcabfd2b8a129a01459c839930b55493 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 12 Mar 2025 16:04:02 +0100 Subject: [PATCH 5/7] make style --- src/transformers/models/whisper/generation_whisper.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index f6556fa82b8d..32959863d700 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -175,7 +175,7 @@ def _pad_to_max_length( # see https://github.com/huggingface/transformers/pull/34537 sequences_list.append(d["tokens"][:-1]) else: - sequences_list.append(d["tokens"]) + sequences_list.append(d["tokens"]) sequence = torch.cat(sequences_list, dim=-1) if return_token_timestamps: diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 8f941c495bf6..c083b68f0852 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2049,7 +2049,9 @@ def test_large_timestamp_generation(self): ).input_features input_features = input_features.to(torch_device) - generated_ids = model.generate(input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True).to("cpu") + generated_ids = model.generate( + input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True + ).to("cpu") # fmt: off EXPECTED_OUTPUT = torch.tensor([ From 954c36830e105c3922594440fb30843c9e0092c9 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 12 Mar 2025 16:16:30 +0100 Subject: [PATCH 6/7] update warning --- .../pipelines/automatic_speech_recognition.py | 18 
++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index d37d0b9cbeee..46e39c1c8aff 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -296,18 +296,20 @@ def _sanitize_parameters(
         # No parameters on this pipeline right now
         preprocess_params = {}
         if chunk_length_s is not None:
-            if self.type == "seq2seq" and not ignore_warning:
-                logger.warning(
+            if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
+                type_warning = (
                     "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                     " be entirely accurate and will have caveats. More information:"
                     " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
-                    " ignore_warning=True)"
-                )
-            elif self.type == "seq2seq_whisper" and not ignore_warning:
-                logger.warning(
-                    "Using `chunk_length_s` with Whisper models is not recommended and will result in unreliable results, as it uses it's own chunking mechanism "
-                    "(cf. Whisper original paper, section 3.8. Long-form Transcription)."
+                    " ignore_warning=True)."
                 )
+                if self.type == "seq2seq_whisper":
+                    type_warning += (
+                        " To use Whisper for long-form transcription, use the model's `generate` method directly instead, "
+                        "as the model relies on its own chunking mechanism (cf. the original Whisper paper, section 3.8, "
+                        "Long-form Transcription)."
+                    )
+                logger.warning(type_warning)
         preprocess_params["chunk_length_s"] = chunk_length_s
         if stride_length_s is not None:
             preprocess_params["stride_length_s"] = stride_length_s

From 92dd435246896dc240e2259537b99cecfc2fbee0 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Thu, 26 Jun 2025 16:19:41 +0200
Subject: [PATCH 7/7] add doc

---
 src/transformers/models/whisper/generation_whisper.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index ee6a1d03f9a7..2a64e599d065 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -139,6 +139,15 @@ def _pad_to_max_length(
     skip_ending_double_timestamps=False,
     timestamp_begin=None,
 ):
+    """
+    skip_ending_double_timestamps: whether to ignore the last timestamp token when a segment ends with two timestamp tokens,
+    see https://github.com/huggingface/transformers/pull/35750
+
+    _pad_to_max_length is used in different contexts:
+    1. At the end of generation: we need to keep both ending timestamp tokens in the segment (see https://github.com/huggingface/transformers/pull/34537).
+    2. In the middle of generation, e.g. when condition_on_prev_tokens=True and we want to use the last generated tokens as decoder_input_ids:
+       we must skip one of the double ending timestamp tokens (see https://github.com/huggingface/transformers/pull/35750).
+    """
     max_total_length = 0
     sequences = []
     token_timestamps_list = []
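
A minimal sketch (not part of the patch series) of the long-form path these commits touch; the checkpoint, dataset, and parameter values below are illustrative assumptions. Calling `generate` directly with `condition_on_prev_tokens=True` and `return_timestamps=True` exercises the decoder-prompt path fixed in PATCH 3/7, while the per-segment timestamp offsets added in PATCH 1/7 apply when the same kind of output is decoded through the ASR pipeline's `_decode_asr`.

```python
# Illustrative only: checkpoint and dataset are placeholder choices. Any audio
# longer than 30 s triggers Whisper's sequential long-form generation loop.
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# ~1 minute of audio; truncation=False keeps the full waveform instead of
# cropping it to a single 30 s window.
audio = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")[0]["audio"]["array"]
inputs = processor(
    audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

# condition_on_prev_tokens=True feeds the previous segment back as decoder_input_ids,
# the path where the double ending-timestamp token is now skipped (PATCH 3/7).
generated_ids = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,
    return_timestamps=True,
    condition_on_prev_tokens=True,
)

# Plain decoding of the concatenated long-form output; when these outputs go through
# the ASR pipeline instead, _decode_asr (PATCH 1/7) offsets each inner segment's timestamps.
print(processor.batch_decode(generated_ids, skip_special_tokens=True, decode_with_timestamps=True)[0])
```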