Merged
27 changes: 25 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -908,7 +908,7 @@ def _convert_to_list(token_ids):
    return token_ids


-def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
+def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
    """
    Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
    the various options not allowed in other seq2seq models
@@ -960,6 +960,12 @@ def new_chunk():
        last_timestamp = None
        first_timestamp = timestamp_begin

+        # long-form generation: we need to handle the case where a single call to generate
+        # returns segments concatenated from multiple underlying calls to generate
+        cur_max_timestamp = 0.0
+        prev_segments_len = 0.0
+        penultimate_timestamp = 0.0
+
        if "stride" in output:
            chunk_len, stride_left, stride_right = output["stride"]
            # Offset the timings to account for the other `model_outputs`.
@@ -1022,7 +1028,24 @@
                    pass
            elif token >= timestamp_begin:
                # 3/ Timestamp token
-                time = (token - timestamp_begin) * time_precision + time_offset
+
+                timestamp = float((token - timestamp_begin) * time_precision)
+                if timestamp < cur_max_timestamp:
+                    # next segment has started
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                    )
+                    if last_was_single_ending:
+                        prev_segments_len += time_precision * segment_size
+                    else:
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+
+                penultimate_timestamp = cur_max_timestamp
+                cur_max_timestamp = timestamp
+
+                time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len
+
                time = round(time, 2)
                if last_timestamp and token >= last_timestamp:
                    # Whisper outputted a timestamp token, but it falls within
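A note on the new default: `segment_size=1500` matches the number of encoder frames in one 30 s Whisper window, so with the usual `time_precision` of 0.02 s, `time_precision * segment_size` comes out to the 30 s that gets added when a segment ends on a single timestamp. The sketch below (not part of the PR) replays the rewind bookkeeping from the diff on a toy token stream; `TIMESTAMP_BEGIN` and all token values are invented for illustration, and only the two constants mirror the defaults above.

```python
# Self-contained sketch of the rewind bookkeeping above, replayed on a toy
# token stream. TIMESTAMP_BEGIN and the token values are invented for
# illustration; TIME_PRECISION and SEGMENT_SIZE mirror the defaults in the diff.
TIMESTAMP_BEGIN = 100  # toy id; tokens >= this are timestamp tokens
TIME_PRECISION = 0.02  # seconds per timestamp step
SEGMENT_SIZE = 1500    # frames per 30 s window (1500 * 0.02 s = 30 s)


def absolute_times(token_ids, time_offset=0.0):
    """Return (token, absolute_time) for every timestamp token in token_ids."""
    cur_max_timestamp = 0.0
    prev_segments_len = 0.0
    penultimate_timestamp = 0.0
    resolved = []
    for i, token in enumerate(token_ids):
        if token < TIMESTAMP_BEGIN:
            continue  # regular text token: no timing information
        timestamp = float((token - TIMESTAMP_BEGIN) * TIME_PRECISION)
        if timestamp < cur_max_timestamp:
            # Timestamps went backwards: generate started a new segment.
            last_was_single_ending = i >= 2 and not (
                token_ids[i - 1] >= TIMESTAMP_BEGIN and token_ids[i - 2] >= TIMESTAMP_BEGIN
            )
            if last_was_single_ending:
                # Previous segment ended on a single timestamp, i.e. it spanned
                # the full window: shift by the whole segment length (30 s).
                prev_segments_len += TIME_PRECISION * SEGMENT_SIZE
            else:
                # Previous segment ended on a timestamp pair: its true length
                # is the last timestamp seen before the restart.
                cur_max_timestamp = penultimate_timestamp
                prev_segments_len += penultimate_timestamp
        penultimate_timestamp = cur_max_timestamp
        cur_max_timestamp = timestamp
        resolved.append((token, round(timestamp + time_offset + prev_segments_len, 2)))
    return resolved


# Two segments: the first ends on the pair <|4.00|><|4.00|>, then timestamps
# restart at <|0.00|>, so the second segment is shifted by 4.00 s.
print(absolute_times([100, 1, 2, 250, 300, 300, 100, 3, 200]))
# -> [(100, 0.0), (250, 3.0), (300, 4.0), (300, 4.0), (100, 4.0), (200, 6.0)]
```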
5 changes: 5 additions & 0 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -303,6 +303,11 @@ def _sanitize_parameters(
                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                    " ignore_warning=True)"
                )
+            elif self.type == "seq2seq_whisper" and not ignore_warning:
+                logger.warning(
+                    "Using `chunk_length_s` with Whisper models is not recommended and will produce unreliable results, as Whisper uses its own chunking mechanism "
+                    "(cf. the original Whisper paper, section 3.8, Long-form Transcription)."
Collaborator

As I mentioned offline, it would be a pity not to use that batch algorithm in some cases! But up for debate!

Contributor Author

True! I just want to make sure:

  1. the user knows that for seq2seq models the pipeline's chunking mechanism is unreliable. A warning already exists for that; it just doesn't take Whisper into account...
  2. the user is not using the pipeline for long-form transcription (or at least knows there is something more reliable when it comes to Whisper)!!

I've updated the warning accordingly
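(For reference, a minimal sketch of the more reliable long-form path being pointed to here, assuming a recent transformers version; the checkpoint and audio file are placeholders:)

```python
# Hedged sketch, not part of the PR: let Whisper's own sequential long-form
# algorithm (paper section 3.8) handle audio longer than 30 s, instead of the
# pipeline's chunking. Checkpoint and file name are placeholders.
from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",  # any Whisper checkpoint
    return_timestamps=True,       # timestamps for the long-form segments
)

# No `chunk_length_s`: the full audio is handed to Whisper's generate-based
# long-form transcription rather than being pre-chunked by the pipeline.
result = pipe("long_audio.wav")
print(result["chunks"])
```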


I'm a bit confused. Can we rely on the pipeline to produce accurate word-level timestamps for long-form transcription with Whisper models? If not, what are the potential pitfalls or failure modes? The merge request “Token-level timestamps for long-form generation in Whisper” doesn't seem to discuss this issue in much detail.

+                )
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s
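For completeness, a hedged sketch of the call pattern the new branch warns about; the checkpoint and audio path are placeholders, and `ignore_warning=True` opts back into the pipeline's chunked batch algorithm:

```python
# Sketch of the call that triggers the new `seq2seq_whisper` warning above.
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Emits the warning: the pipeline's chunked algorithm overrides Whisper's own
# long-form mechanism.
out = pipe("long_audio.wav", chunk_length_s=30)

# Opting in deliberately, with the warning silenced:
out = pipe("long_audio.wav", chunk_length_s=30, ignore_warning=True)
```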