diff --git a/src/transformers/models/albert/tokenization_albert.py b/src/transformers/models/albert/tokenization_albert.py
index cfcfcd9daa1d..5bebb936cf7d 100644
--- a/src/transformers/models/albert/tokenization_albert.py
+++ b/src/transformers/models/albert/tokenization_albert.py
@@ -250,7 +250,23 @@ def _convert_id_to_token(self, index):
         return self.sp_model.IdToPiece(index)

     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
diff --git a/src/transformers/models/barthez/tokenization_barthez.py b/src/transformers/models/barthez/tokenization_barthez.py
index 5f12adb7a336..2e58db113e15 100644
--- a/src/transformers/models/barthez/tokenization_barthez.py
+++ b/src/transformers/models/barthez/tokenization_barthez.py
@@ -263,6 +263,25 @@ def _convert_id_to_token(self, index):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index)

+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -278,10 +297,6 @@ def __setstate__(self, d):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
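Note: the `convert_tokens_to_string` added above keeps special tokens out of the SentencePiece call and splices them back in verbatim, instead of handing the whole token list to `sp_model.decode` (which only knows the pieces in its own vocabulary). A minimal sketch of the same accumulation loop, using a toy `fake_sp_decode` stand-in for `sp_model.decode` and arbitrary example tokens (`[SEP]` and `<pad>` here are just illustrative choices):

def fake_sp_decode(pieces):
    # Toy stand-in for sp_model.decode: join pieces and turn the
    # SentencePiece word-boundary marker "▁" back into spaces.
    return "".join(pieces).replace("\u2581", " ").strip()

def convert_tokens_to_string(tokens, special_tokens):
    current_sub_tokens = []
    out_string = ""
    prev_is_special = False
    for token in tokens:
        # special tokens are emitted verbatim, never fed to SentencePiece
        if token in special_tokens:
            if not prev_is_special:
                out_string += " "
            out_string += fake_sp_decode(current_sub_tokens) + token
            prev_is_special = True
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
            prev_is_special = False
    out_string += fake_sp_decode(current_sub_tokens)
    return out_string.strip()

print(convert_tokens_to_string(["\u2581Hello", "\u2581world", ".", "[SEP]"], {"[SEP]"}))
# -> "Hello world.[SEP]"
print(convert_tokens_to_string(["\u2581Hello", "<pad>", "\u2581world"], {"<pad>"}))
# -> "Hello<pad>world"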
diff --git a/src/transformers/models/bert_generation/tokenization_bert_generation.py b/src/transformers/models/bert_generation/tokenization_bert_generation.py
index 2ff9382a7b5b..711dcdf50c25 100644
--- a/src/transformers/models/bert_generation/tokenization_bert_generation.py
+++ b/src/transformers/models/bert_generation/tokenization_bert_generation.py
@@ -151,8 +151,17 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
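Note: BertGeneration (like Pegasus, Reformer and M2M100 further down) gets the simpler variant without the `prev_is_special` flag, so no space is re-inserted between a special token and the text that follows it. A rough comparison on made-up tokens, again with a toy stand-in for `sp_model.decode`:

def fake_sp_decode(pieces):
    # toy stand-in for sp_model.decode, not the real SentencePiece API
    return "".join(pieces).replace("\u2581", " ").strip()

def simple_variant(tokens, special_tokens):
    # BertGeneration/Pegasus/Reformer-style loop: no prev_is_special bookkeeping
    current_sub_tokens = []
    out_string = ""
    for token in tokens:
        if token in special_tokens:
            out_string += fake_sp_decode(current_sub_tokens) + token
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    return (out_string + fake_sp_decode(current_sub_tokens)).strip()

print(simple_variant(["\u2581Hello", "<pad>", "\u2581world", "</s>"], {"<pad>", "</s>"}))
# -> "Hello<pad>world</s>"
# (the prev_is_special variant used for Albert/Barthez above gives "Hello<pad> world</s>")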
diff --git a/src/transformers/models/big_bird/tokenization_big_bird.py b/src/transformers/models/big_bird/tokenization_big_bird.py
index f39aa29d0c03..47c00fa7c2fa 100644
--- a/src/transformers/models/big_bird/tokenization_big_bird.py
+++ b/src/transformers/models/big_bird/tokenization_big_bird.py
@@ -16,6 +16,7 @@


 import os
+import re
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple

@@ -182,8 +183,65 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space before [MASK] and [SEP]
+        if spaces_between_special_tokens:
+            text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
diff --git a/src/transformers/models/camembert/tokenization_camembert.py b/src/transformers/models/camembert/tokenization_camembert.py
index 60394148053e..f5988fd9d784 100644
--- a/src/transformers/models/camembert/tokenization_camembert.py
+++ b/src/transformers/models/camembert/tokenization_camembert.py
@@ -261,6 +261,25 @@ def _convert_id_to_token(self, index):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)

+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -276,10 +295,6 @@ def __setstate__(self, d):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
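Note: the `re.sub` in BigBird's new `_decode` removes the space that `" ".join(sub_texts)` would otherwise leave in front of `[MASK]` and `[SEP]`, matching what the Rust tokenizer produces. For example (sentence is an arbitrary illustration):

import re

sub_texts = ["Paris is the", "[MASK]", "of France.", "[SEP]"]
joined = " ".join(sub_texts)
# joined == "Paris is the [MASK] of France. [SEP]"
text = re.sub(r" (\[(MASK|SEP)\])", r"\1", joined)
print(text)
# -> "Paris is the[MASK] of France.[SEP]"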
""" - def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None + ): self.split_by_punct = split_by_punct self.vocab_file = vocab_file self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs @@ -312,6 +316,7 @@ def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[D # self.vocab['[UNK]'] = 3 self.spm = spm + self.special_tokens = special_tokens def __getstate__(self): state = self.__dict__.copy() @@ -339,7 +344,22 @@ def convert_ids_to_tokens(self, ids): def decode(self, tokens, start=-1, end=-1, raw_text=None): if raw_text is None: - return self.spm.decode_pieces([t for t in tokens]) + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.spm.decode_pieces(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.spm.decode_pieces(current_sub_tokens) + return out_string.strip() else: words = self.split_to_words(raw_text) word_tokens = [self.tokenize(w) for w in words] diff --git a/src/transformers/models/fnet/tokenization_fnet.py b/src/transformers/models/fnet/tokenization_fnet.py index 6143a9b08f2f..e7e3adfd793a 100644 --- a/src/transformers/models/fnet/tokenization_fnet.py +++ b/src/transformers/models/fnet/tokenization_fnet.py @@ -15,6 +15,7 @@ """ Tokenization classes for FNet model.""" import os +import re import unicodedata from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple @@ -213,7 +214,66 @@ def _convert_id_to_token(self, index): return self.sp_model.IdToPiece(index) def convert_tokens_to_string(self, tokens): - return self.sp_model.decode(tokens) + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + spaces_between_special_tokens: bool = True, + **kwargs + ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separately for added tokens and byte-level tokens + # cf. 
diff --git a/src/transformers/models/fnet/tokenization_fnet.py b/src/transformers/models/fnet/tokenization_fnet.py
index 6143a9b08f2f..e7e3adfd793a 100644
--- a/src/transformers/models/fnet/tokenization_fnet.py
+++ b/src/transformers/models/fnet/tokenization_fnet.py
@@ -15,6 +15,7 @@
 """ Tokenization classes for FNet model."""

 import os
+import re
 import unicodedata
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -213,7 +214,66 @@ def _convert_id_to_token(self, index):
         return self.sp_model.IdToPiece(index)

     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space after <unk>
+        if spaces_between_special_tokens:
+            text = re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py
index fff596046e3f..984d05cd582d 100644
--- a/src/transformers/models/m2m_100/tokenization_m2m_100.py
+++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -218,9 +218,19 @@ def _convert_id_to_token(self, index: int) -> str:
             return self.id_to_lang_token[index]
         return self.decoder.get(index, self.unk_token)

-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py
index 66eb5a44c5bf..c688733321be 100644
--- a/src/transformers/models/marian/tokenization_marian.py
+++ b/src/transformers/models/marian/tokenization_marian.py
@@ -265,10 +265,18 @@ def decode(self, token_ids, **kwargs):

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
-        if self._decode_use_source_tokenizer:
-            return self.spm_source.DecodePieces(tokens)
-        else:
-            return self.spm_target.DecodePieces(tokens)
+        sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += sp_model.decode_pieces(current_sub_tokens)
+        return out_string.strip()

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
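Note: the `_decode` overrides added to BigBird and FNet (and to XLNet further down) share the same splitting step: tokens found in `added_tokens_encoder` are cut out of the stream, the runs in between are decoded with `convert_tokens_to_string`, and the pieces are rejoined. A self-contained sketch of just that step, with plain-Python stand-ins (a set of added tokens and a join function) instead of the real tokenizer attributes; `<special>` is a made-up token:

def split_on_added_tokens(filtered_tokens, added_tokens, to_string):
    """Decode runs of ordinary tokens separately from added/special tokens."""
    sub_texts = []
    current_sub_text = []
    for token in filtered_tokens:
        if token in added_tokens:
            if current_sub_text:
                sub_texts.append(to_string(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)  # added tokens pass through unchanged
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append(to_string(current_sub_text))
    return sub_texts

# Toy stand-ins: pieces are plain words, "to_string" just joins them with spaces.
pieces = ["my", "dog", "<special>", "is", "cute"]
print(split_on_added_tokens(pieces, {"<special>"}, " ".join))
# -> ['my dog', '<special>', 'is cute']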
diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py
index 707a97734927..0a331b283760 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50.py
@@ -232,9 +232,24 @@ def _convert_id_to_token(self, index: int) -> str:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)

-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py
index b4d1cdc19804..77127125bb48 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus.py
@@ -231,8 +231,17 @@ def _convert_id_to_token(self, index: int) -> str:

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def num_special_tokens_to_add(self, pair=False):
         """Just EOS"""
diff --git a/src/transformers/models/reformer/tokenization_reformer.py b/src/transformers/models/reformer/tokenization_reformer.py
index d5d73f3e451f..814d5ed6cde1 100644
--- a/src/transformers/models/reformer/tokenization_reformer.py
+++ b/src/transformers/models/reformer/tokenization_reformer.py
@@ -158,8 +158,17 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
diff --git a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
index e1bc681499f7..843c79e397b8 100644
--- a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
@@ -190,11 +190,19 @@ def _convert_id_to_token(self, index: int) -> str:

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        out_string = self.sp_model.decode(tokens)
-
-        if self.do_upper_case:
-            out_string = out_string.upper()
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                decoded = self.sp_model.decode(current_sub_tokens)
+                out_string += (decoded.upper() if self.do_upper_case else decoded) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        decoded = self.sp_model.decode(current_sub_tokens)
+        out_string += decoded.upper() if self.do_upper_case else decoded
+        return out_string.strip()

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py
index 2dbc788374dc..5d016ab7d835 100644
--- a/src/transformers/models/t5/tokenization_t5.py
+++ b/src/transformers/models/t5/tokenization_t5.py
@@ -311,14 +311,19 @@ def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
         current_sub_tokens = []
         out_string = ""
+        prev_is_special = False
         for token in tokens:
             # make sure that special tokens are not decoded using sentencepiece model
             if token in self.all_special_tokens:
-                out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
                 current_sub_tokens = []
             else:
                 current_sub_tokens.append(token)
-        out_string += self.sp_model.decode_pieces(current_sub_tokens)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
         return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
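Note: in the Speech2Text version above, `do_upper_case` is applied only to the SentencePiece-decoded segments, so special tokens keep their original casing. A rough sketch with a toy stand-in for `sp_model.decode` and `</s>` as the example special token:

def fake_sp_decode(pieces):
    # toy stand-in for sp_model.decode: join pieces, "▁" marks a word boundary
    return "".join(pieces).replace("\u2581", " ").strip()

def to_string(tokens, special_tokens, do_upper_case=True):
    current_sub_tokens = []
    out_string = ""
    for token in tokens:
        if token in special_tokens:
            decoded = fake_sp_decode(current_sub_tokens)
            # upper-case the decoded text, but append the special token as-is
            out_string += (decoded.upper() if do_upper_case else decoded) + token + " "
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    decoded = fake_sp_decode(current_sub_tokens)
    out_string += decoded.upper() if do_upper_case else decoded
    return out_string.strip()

print(to_string(["\u2581hello", "\u2581world", "</s>"], {"</s>"}))
# -> "HELLO WORLD</s>"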
diff --git a/src/transformers/models/xlnet/tokenization_xlnet.py b/src/transformers/models/xlnet/tokenization_xlnet.py
index 920a9f5cb74c..9dc6fd245964 100644
--- a/src/transformers/models/xlnet/tokenization_xlnet.py
+++ b/src/transformers/models/xlnet/tokenization_xlnet.py
@@ -250,6 +250,46 @@ def convert_tokens_to_string(self, tokens):
         out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
         return out_string

+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # By default, there are no spaces between special tokens
+        text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
diff --git a/tests/models/deberta_v2/test_tokenization_deberta_v2.py b/tests/models/deberta_v2/test_tokenization_deberta_v2.py
index c84034c7f0bc..f2831315e5c2 100644
--- a/tests/models/deberta_v2/test_tokenization_deberta_v2.py
+++ b/tests/models/deberta_v2/test_tokenization_deberta_v2.py
@@ -37,7 +37,7 @@ def setUp(self):
         super().setUp()

         # We have a SentencePiece fixture for testing
-        tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB)
+        tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB, unk_token="<unk>")
         tokenizer.save_pretrained(self.tmpdirname)

     def get_input_output_texts(self, tokenizer):
@@ -55,7 +55,6 @@ def test_convert_token_and_id(self):

     def test_get_vocab(self):
         vocab_keys = list(self.get_tokenizer().get_vocab().keys())

-        self.assertEqual(vocab_keys[0], "")
         self.assertEqual(vocab_keys[1], "")
         self.assertEqual(vocab_keys[-1], "[PAD]")

@@ -80,6 +79,14 @@ def test_do_lower_case(self):

         self.assertListEqual(rust_tokens, tokens_target)

+    @unittest.skip("There is an inconsistency between slow and fast tokenizer due to a bug in the fast one.")
+    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
+        pass
+
+    @unittest.skip("There is an inconsistency between slow and fast tokenizer due to a bug in the fast one.")
+    def test_sentencepiece_tokenize_and_decode(self):
+        pass
+
     def test_split_by_punct(self):
         # fmt: off
         sequence = "I was born in 92000, and this is falsé."
diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
index 545e896e9ed6..ffc58fe0abee 100644
--- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1946,3 +1946,11 @@ def test_layoutxlm_integration_test(self):
     @unittest.skip("Doesn't support another framework than PyTorch")
     def test_np_encode_plus_sent_to_model(self):
         pass
+
+    @unittest.skip("Doesn't use SentencePiece")
+    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
+        pass
+
+    @unittest.skip("Doesn't use SentencePiece")
+    def test_sentencepiece_tokenize_and_decode(self):
+        pass
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 6ece29b71844..8ca460449e24 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -385,6 +385,33 @@ def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):

         self.assertEqual(reverse_text, text)

+        special_tokens = tokenizer.all_special_tokens
+        special_tokens_string = tokenizer.convert_tokens_to_string(special_tokens)
+        for special_token in special_tokens:
+            self.assertIn(special_token, special_tokens_string)
+
+        if self.test_rust_tokenizer:
+            rust_tokenizer = self.get_rust_tokenizer()
+            special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens)
+            self.assertEqual(special_tokens_string, special_tokens_string_rust)
+
+    def test_sentencepiece_tokenize_and_decode(self):
+        if not self.test_sentencepiece:
+            return
+
+        text = "This is text to test the tokenizer."
+        if self.test_rust_tokenizer:
+            tokenizer = self.get_tokenizer()
+            rust_tokenizer = self.get_rust_tokenizer()
+
+            slow_ids = tokenizer(text).input_ids
+            fast_ids = rust_tokenizer(text).input_ids
+            self.assertEqual(slow_ids, fast_ids)
+
+            slow_decoded = tokenizer.decode(slow_ids)
+            fast_decoded = rust_tokenizer.decode(slow_ids)
+            self.assertEqual(slow_decoded, fast_decoded)
+
     def test_subword_regularization_tokenizer(self) -> None:
         if not self.test_sentencepiece:
             return
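Note: the new `test_sentencepiece_tokenize_and_decode` asserts that a slow tokenizer and its fast counterpart agree on both the encoded ids and the decoded string. A rough manual check along the same lines (assumes network access; `t5-small` is just one convenient SentencePiece-backed checkpoint, any other would do):

from transformers import AutoTokenizer

slow = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
fast = AutoTokenizer.from_pretrained("t5-small", use_fast=True)

text = "This is text to test the tokenizer."
slow_ids = slow(text).input_ids
fast_ids = fast(text).input_ids
assert slow_ids == fast_ids

# decoding the same ids should also agree, including the trailing </s>
assert slow.decode(slow_ids) == fast.decode(slow_ids)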