50 changes: 43 additions & 7 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -852,13 +852,22 @@ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
         return forced_decoder_ids

     def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
-        return _decode_asr(
-            self,
-            model_outputs,
-            return_timestamps=return_timestamps,
-            return_language=return_language,
-            time_precision=time_precision,
-        )
+        if return_timestamps:
+            return _decode_asr_segments(
+                self,
+                model_outputs,
+                return_timestamps=return_timestamps,
+                return_language=return_language,
+                time_precision=time_precision,
+            )
+        else:
+            return _decode_asr(
+                self,
+                model_outputs,
+                return_timestamps=return_timestamps,
+                return_language=return_language,
+                time_precision=time_precision,
+            )

     def get_prompt_ids(self, text: str, return_tensors="np"):
         """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
@@ -908,6 +917,33 @@ def _convert_to_list(token_ids):
     return token_ids


+def _decode_asr_segments(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
+    """
+    Decode segments in the model output to obtain chunks of text with their timestamp ranges
+    """
+    timestamp_begin_id = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
+    prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
+    decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
+    all_special_ids = set(tokenizer.all_special_ids)
+    segments = model_outputs[0]["segments"][0]
+    chunks = []
+    for segment in segments:
+        chunk = {}
+        token_ids = segment["tokens"].tolist()
+        token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
+        start_time = segment["start"].item()
+        end_time = segment["end"].item()
+        # drop timestamp tokens and any other special tokens
+        token_ids = [token for token in token_ids if (token < timestamp_begin_id) and (token not in all_special_ids)]
chunk["text"] = tokenizer.decode(token_ids)
chunk["timestamp"] = (start_time, end_time)
chunks.append(chunk)

full_text = "".join(chunk["text"] for chunk in chunks)
optional = {"chunks": chunks}
return full_text, optional


def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""
Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
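For orientation, the new _decode_asr_segments helper consumes the segments that WhisperForConditionalGeneration.generate(..., return_segments=True) produces and returns the (text, {"chunks": ...}) pair the ASR pipeline's postprocessing expects. Below is a minimal sketch of calling it directly on a hand-built segments structure; the openai/whisper-tiny checkpoint, the dummy tensors, and the sample sentence are illustrative assumptions, and importing _decode_asr_segments assumes this PR's branch is installed.

import torch

from transformers import WhisperTokenizer
from transformers.models.whisper.tokenization_whisper import _decode_asr_segments

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Hand-built stand-in for generate(..., return_segments=True) output: the helper reads
# model_outputs[0]["segments"][0], a list of segments carrying "tokens", "start" and "end".
token_ids = tokenizer(" Hello world.", add_special_tokens=False).input_ids
segment = {
    "tokens": torch.tensor(token_ids),
    "start": torch.tensor(0.0),
    "end": torch.tensor(2.0),
}
model_outputs = [{"segments": [[segment]]}]

text, optional = _decode_asr_segments(
    tokenizer,
    model_outputs,
    return_timestamps=True,
    return_language=None,
    time_precision=0.02,
)
print(text)                # roughly " Hello world."
print(optional["chunks"])  # [{'text': ..., 'timestamp': (0.0, 2.0)}]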
13 changes: 8 additions & 5 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -502,6 +502,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
             # custom processing for Whisper timestamps and word-level timestamps
             if return_timestamps and self.type == "seq2seq_whisper":
                 generate_kwargs["return_timestamps"] = return_timestamps
+                generate_kwargs["return_segments"] = True  # Use segments to return timestamps
                 if return_timestamps == "word":
                     generate_kwargs["return_token_timestamps"] = True
                     generate_kwargs["return_segments"] = True
@@ -524,7 +525,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
                 **generate_kwargs,
             )
             # whisper longform generation stores timestamps in "segments"
-            if return_timestamps == "word" and self.type == "seq2seq_whisper":
+            if self.type == "seq2seq_whisper" and return_timestamps == "word":
                 if "segments" not in tokens:
                     out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
                 else:
@@ -533,11 +534,13 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
                         for segment_list in tokens["segments"]
                     ]
                     out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
-            else:
+            elif self.type == "seq2seq_whisper" and return_timestamps:
+                out = {"tokens": tokens["sequences"], "segments": tokens["segments"]}
+            else:  # seq2seq or whisper without return_timestamps
                 out = {"tokens": tokens}
-            if self.type == "seq2seq_whisper":
-                if stride is not None:
-                    out["stride"] = stride

+            if self.type == "seq2seq_whisper" and stride is not None:
+                out["stride"] = stride
+
         else:
             inputs = {
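At the pipeline level, the net effect of these changes is that return_timestamps=True on a Whisper model now takes the segment-based path above instead of re-parsing timestamp tokens. A hedged usage sketch, where the checkpoint name and audio path are placeholders rather than part of this PR:

from transformers import pipeline

# Placeholder checkpoint and audio file; any Whisper checkpoint and a 16 kHz clip work.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    return_timestamps=True,  # with this PR, timestamps are read from generate()'s "segments"
)

result = asr("sample.flac")
print(result["text"])
for chunk in result["chunks"]:
    # Each chunk pairs a text span with its (start, end) timestamps in seconds.
    print(chunk["timestamp"], chunk["text"])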