diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 45fd5ed4e78a..b3eccddf4103 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+from functools import lru_cache
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import numpy as np
@@ -546,6 +547,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+                # strip timestamp tokens from the text output
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
                 offsets.append(
                     {
                         "text": self._decode(sliced_tokens),
@@ -559,6 +562,47 @@
 
         return offsets
 
+    @lru_cache
+    def timestamp_ids(self, time_precision=0.02):
+        """
+        Compute the timestamp token ids for a given precision and save them to a least-recently used (LRU) cache.
+
+        Args:
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
+
+    def _preprocess_token_ids(
+        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
+    ):
+        """
+        Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Typically obtained using the `__call__` method of the tokenizer.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
+                removed.
+            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
+                filtered out from the token ids.
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        if skip_special_tokens:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
+
+        if not decode_with_timestamps:
+            # filter timestamp tokens if they are contained in the vocab
+            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
+            token_ids = [token for token in token_ids if token not in timestamp_ids]
+
+        return token_ids
+
     def decode(
         self,
         token_ids,
@@ -593,33 +637,40 @@ def decode(
         Returns:
             `str`: The decoded sentence.
         """
-        text = super().decode(
+        filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
+            decode_with_timestamps=decode_with_timestamps,
+            time_precision=time_precision,
+        )
+
+        text = super().decode(
+            filtered_ids,
+            skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            decode_with_timestamps=decode_with_timestamps,
             **kwargs,
         )
         if decode_with_timestamps:
+            # legacy method to decode timestamps when not included in the tokenizer vocabulary
             text = self._decode_with_timestamps(
-                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+                filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
         # retrieve offsets
         if output_offsets:
-            offsets = None
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
             return {"text": text, "offsets": offsets}
         return text
 
     def _decode(
-        self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        normalize: bool = False,
+        decode_with_timestamps: bool = False,
+        **kwargs,
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
-
-        if skip_special_tokens:
-            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
-            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
-            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
-
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
 
         # To avoid mixing byte-level and unicode for byte-level BPT
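Taken together, the two new helpers change what a plain `decode` call returns: timestamp ids are filtered out by default and only rendered when `decode_with_timestamps=True`. Below is a minimal sketch of that behaviour, assuming a checkpoint whose tokenizer already carries the timestamp tokens in its vocabulary (the updated tests further down assert exactly this for `openai/whisper-tiny`); the transcript is illustrative, not taken from the PR:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Mimic generate() output: a transcript framed by timestamp tokens.
text_ids = tokenizer(" Hello world", add_special_tokens=False).input_ids
ts = tokenizer.timestamp_ids()  # ids for <|0.00|> ... <|30.00|>
ids = [ts[0]] + text_ids + [ts[50]]  # framed by <|0.00|> and <|1.00|>

print(tokenizer.decode(ids))  # " Hello world" (timestamps filtered by default)
print(tokenizer.decode(ids, decode_with_timestamps=True))  # "<|0.00|> Hello world<|1.00|>"
```

Filtering the ids up front also means `_decode_with_timestamps` now only covers the legacy case, where the timestamp tokens are not part of the vocabulary.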
""" - text = super().decode( + filtered_ids = self._preprocess_token_ids( token_ids, skip_special_tokens=skip_special_tokens, + decode_with_timestamps=decode_with_timestamps, + time_precision=time_precision, + ) + + text = super().decode( + filtered_ids, + skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, + decode_with_timestamps=decode_with_timestamps, **kwargs, ) if decode_with_timestamps: + # legacy method to decode timestamps when not included in the tokenizer vocabulary text = self._decode_with_timestamps( - token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens + filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens ) # retrieve offsets if output_offsets: - offsets = None offsets = self._compute_offsets(token_ids, time_precision=time_precision) return {"text": text, "offsets": offsets} return text def _decode( - self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + normalize: bool = False, + decode_with_timestamps: bool = False, + **kwargs, ) -> str: self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) - - if skip_special_tokens: - prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>") - decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>") - token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id) - filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) # To avoid mixing byte-level and unicode for byte-level BPT diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 689da150009c..4ad500bbf1c0 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -15,6 +15,7 @@ """Tokenization classes for Whisper.""" import json import os +from functools import lru_cache from typing import TYPE_CHECKING, List, Optional, Tuple import numpy as np @@ -255,6 +256,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): if len(sliced_tokens) > 1: start_timestamp_position = sliced_tokens[0].item() - timestamp_begin end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin + # strip timestamp tokens from the text output + sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False) offsets.append( { "text": self._decode(sliced_tokens), @@ -268,6 +271,49 @@ def _compute_offsets(self, token_ids, time_precision=0.02): return offsets + @lru_cache + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.timestamp_ids + def timestamp_ids(self, time_precision=0.02): + """ + Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache. + + Args: + time_precision (`float`, `optional`, defaults to 0.02): + The time ratio to convert from token to time. 
+ """ + return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]) + + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids + def _preprocess_token_ids( + self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02 + ): + """ + Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be + removed. + decode_with_timestamps (`bool`, *optional*, defaults to `False`): + Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be + filtered out from the token ids. + time_precision (`float`, `optional`, defaults to 0.02): + The time ratio to convert from token to time. + """ + if skip_special_tokens: + prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>") + decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>") + token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id) + + if not decode_with_timestamps: + # filter timestamp tokens if they are contained in the vocab + timestamp_ids = self.timestamp_ids(time_precision=time_precision) + token_ids = [token for token in token_ids if token not in timestamp_ids] + + return token_ids + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode def decode( self, @@ -303,29 +349,32 @@ def decode( Returns: `str`: The decoded sentence. 
""" - text = super().decode( + filtered_ids = self._preprocess_token_ids( token_ids, skip_special_tokens=skip_special_tokens, + decode_with_timestamps=decode_with_timestamps, + time_precision=time_precision, + ) + + text = super().decode( + filtered_ids, + skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, + decode_with_timestamps=decode_with_timestamps, **kwargs, ) if decode_with_timestamps: + # legacy method to decode timestamps when not included in the tokenizer vocabulary text = self._decode_with_timestamps( - token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens + filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens ) # retrieve offsets if output_offsets: - offsets = None offsets = self._compute_offsets(token_ids, time_precision=time_precision) return {"text": text, "offsets": offsets} return text def _decode(self, *args, normalize: bool = False, **kwargs) -> str: - if kwargs["skip_special_tokens"]: - prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>") - decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>") - kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id) - text = super()._decode(*args, **kwargs) if normalize: diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 3ded58960f6e..9ab29d29d1de 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -52,14 +52,13 @@ def test_convert_token_and_id(self): self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id) self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token) - @unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time") def test_get_vocab(self): vocab_keys = list(self.get_tokenizer().get_vocab().keys()) self.assertEqual(vocab_keys[0], "!") self.assertEqual(vocab_keys[1], '"') - self.assertEqual(vocab_keys[-1], "<|notimestamps|>") - self.assertEqual(len(vocab_keys), 50364) + self.assertEqual(vocab_keys[-1], "<|30.00|>") + self.assertEqual(len(vocab_keys), 51865) def test_vocab_size(self): self.assertEqual(self.get_tokenizer().vocab_size, 50258) @@ -117,7 +116,6 @@ def test_tokenizer_integration(self): expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False ) - @unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time") def test_output_offsets(self): tokenizer = self.get_tokenizer() previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612] @@ -400,7 +398,6 @@ def test_batch_encoding_decoding(self): transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True) self.assertListEqual(batch, transcription) - @unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time") def test_offset_decoding(self): multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny") # fmt: off