From be7c3b8588bf7fb6346c93d5e705d1d9c9440046 Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Wed, 20 Sep 2023 16:25:34 +0100
Subject: [PATCH] make decoding faster

---
 .../models/whisper/tokenization_whisper.py    | 31 ++++++++---------
 .../whisper/tokenization_whisper_fast.py      | 33 +++++++++----------
 2 files changed, 30 insertions(+), 34 deletions(-)

diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 6c3cebbe23d5..a6f13d4f38dd 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -314,6 +314,7 @@ def __init__(
 
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(
@@ -560,10 +561,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -585,9 +588,7 @@ def timestamp_ids(self, time_precision=0.02):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
 
@@ -597,24 +598,17 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,
@@ -644,6 +638,8 @@ def decode(
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -652,8 +648,6 @@ def decode(
             `str`: The decoded sentence.
         """
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -668,6 +662,9 @@ def decode(
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index c85b945685fa..71b741be52b3 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple
 
@@ -190,6 +191,7 @@ def __init__(
             self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ def timestamp_ids(self, time_precision=0.02):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
 
@@ -308,24 +310,18 @@ def _preprocess_token_ids(
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
         token_ids,
@@ -356,6 +352,8 @@ def decode(
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@ def decode(
             `str`: The decoded sentence.
         """
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -380,6 +376,9 @@ def decode(
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)