Merged
24 changes: 15 additions & 9 deletions src/transformers/models/whisper/generation_whisper.py
@@ -136,6 +136,8 @@ def _pad_to_max_length(
cut_off_length=None,
return_token_timestamps=False,
force_unique_generate_call=False,
skip_ending_double_timestamps=False,
Collaborator:

we are missing documentation on this one no?
Contributor Author:

added a comment explaining this hidden parameter and links to the related PRs to understand why we need it
timestamp_begin=None,
):
max_total_length = 0
sequences = []
@@ -166,7 +168,17 @@ def _pad_to_max_length(

for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
sequences_list = []
for d in current_segment_list:
if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin:
# the segment finishes with two timestamp tokens
# we need to ignore the last timestamp token
# see https://github.com/huggingface/transformers/pull/34537
sequences_list.append(d["tokens"][:-1])
else:
sequences_list.append(d["tokens"])
sequence = torch.cat(sequences_list, dim=-1)

if return_token_timestamps:
token_timestamps = torch.cat(
[d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
@@ -1809,14 +1821,6 @@ def _prepare_decoder_input_ids(
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

for segments in active_segments:
for seg in segments:
if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
# the segment finishes with two timestamp tokens
# we need to ignore the last timestamp token
# see https://github.com/huggingface/transformers/pull/34537
seg["tokens"] = seg["tokens"][:-1]

if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
@@ -1833,6 +1837,8 @@ def _prepare_decoder_input_ids(
padding=padding,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
skip_ending_double_timestamps=True,
timestamp_begin=timestamp_begin,
)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)

27 changes: 25 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -910,7 +910,7 @@ def _convert_to_list(token_ids):
return token_ids


def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
"""
Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
the various options not allowed in other seq2seq models
@@ -962,6 +962,12 @@ def new_chunk():
last_timestamp = None
first_timestamp = timestamp_begin

# long-form generation: we need to handle the case where a single call to generate returns segments
# concatenated from multiple underlying calls to generate
cur_max_timestamp = 0.0
prev_segments_len = 0.0
penultimate_timestamp = 0.0

if "stride" in output:
chunk_len, stride_left, stride_right = output["stride"]
# Offset the timings to account for the other `model_outputs`.
@@ -1024,7 +1030,24 @@ def new_chunk():
pass
elif token >= timestamp_begin:
# 3/ Timestamp token
time = (token - timestamp_begin) * time_precision + time_offset

timestamp = float((token - timestamp_begin) * time_precision)
if timestamp < cur_max_timestamp:
# next segment has started
last_was_single_ending = i >= 2 and not (
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
)
if last_was_single_ending:
prev_segments_len += time_precision * segment_size
else:
cur_max_timestamp = penultimate_timestamp
prev_segments_len += penultimate_timestamp

penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp

time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len

time = round(time, 2)
if last_timestamp and token >= last_timestamp:
# Whisper outputted a timestamp token, but it falls within
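To make the new bookkeeping in `_decode_asr` concrete: within a single call to `generate`, timestamp tokens restart from zero at every internal segment, so a decoded timestamp that is smaller than the running maximum signals a segment boundary. Below is a small sketch of the resulting arithmetic, assuming the usual `time_precision=0.02` and the default `segment_size=1500` (i.e. 30-second segments); the numbers are illustrative, not taken from the diff:

time_precision = 0.02  # seconds per timestamp step
segment_size = 1500    # frames per segment -> 1500 * 0.02 = 30.0 s

prev_segments_len = 0.0
cur_max_timestamp = 30.0  # the previous segment ended with a single 30.0 s timestamp
timestamp = 4.0           # the next decoded timestamp has wrapped around

if timestamp < cur_max_timestamp:
    # single ending timestamp: assume the previous segment ran its full length
    prev_segments_len += time_precision * segment_size

absolute_time = timestamp + prev_segments_len
print(absolute_time)  # 34.0 -> reported instead of the raw 4.0

When the previous segment instead ends with a pair of timestamps, the branch above adds the penultimate timestamp rather than the full segment length, since decoding stopped early at that point.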
13 changes: 10 additions & 3 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -283,13 +283,20 @@ def _sanitize_parameters(
# No parameters on this pipeline right now
preprocess_params = {}
if chunk_length_s is not None:
if self.type == "seq2seq" and not ignore_warning:
logger.warning(
if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
type_warning = (
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
" be entirely accurate and will have caveats. More information:"
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
" ignore_warning=True)"
" ignore_warning=True)."
)
if self.type == "seq2seq_whisper":
type_warning += (
" To use Whisper for long-form transcription, use rather the model's `generate` method directly "
"as the model relies on it's own chunking mechanism (cf. Whisper original paper, section 3.8. "
"Long-form Transcription)."
)
logger.warning(type_warning)
preprocess_params["chunk_length_s"] = chunk_length_s
if stride_length_s is not None:
preprocess_params["stride_length_s"] = stride_length_s
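The recommendation added to the warning, as a hedged usage sketch: pass the un-truncated audio to the model and let `generate` handle the long-form chunking itself. The checkpoint name and audio below are placeholders and the generation flags are illustrative:

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")

# placeholder audio: one minute of silence sampled at 16 kHz
raw_audio = np.zeros(16_000 * 60, dtype=np.float32)

# keep the full audio instead of truncating it to a single 30 s window
inputs = processor(
    raw_audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

# long-form transcription: the model chunks internally (Whisper paper, section 3.8)
generated_ids = model.generate(**inputs, return_timestamps=True, condition_on_prev_tokens=True)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]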
8 changes: 5 additions & 3 deletions tests/models/whisper/test_modeling_whisper.py
@@ -2031,11 +2031,13 @@ def test_large_timestamp_generation(self):
).input_features
input_features = input_features.to(torch_device)

generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
generated_ids = model.generate(
input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True
).to("cpu")

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
[50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50431]
])
# fmt: on
torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
@@ -2078,7 +2080,7 @@ def test_large_timestamp_generation(self):
},
{
"text": (" and can discover"),
"timestamp": (28.68, 29.98),
"timestamp": (28.68, 30.0),
},
],
}
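For reference, entries like the expected {"text": " and can discover", "timestamp": (28.68, 30.0)} above are the kind of per-segment offsets the tokenizer can attach when decoding timestamped ids. A hedged sketch, reusing `processor` and `generated_ids` as they appear in the test above:

# decode the timestamped ids and ask the tokenizer for per-segment offsets
decoded = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
for segment in decoded[0]["offsets"]:
    # each entry pairs a text span with its (start, end) timestamps, e.g. (28.68, 30.0)
    print(segment["timestamp"], segment["text"])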