diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 6a5bff3679f8..df13a029a6c6 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -1077,7 +1077,7 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
 
     def _decode(
         self,
-        token_ids: List[int],
+        token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = None,
         spaces_between_special_tokens: bool = True,
@@ -1086,6 +1086,10 @@ def _decode(
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
 
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+        # If given is a single id, prevents splitting the string in upcoming loop
+        if isinstance(filtered_tokens, str):
+            filtered_tokens = [filtered_tokens]
+
         legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
             token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
         }
@@ -1096,7 +1100,7 @@ def _decode(
         current_sub_text = []
         # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
         for token in filtered_tokens:
-            if skip_special_tokens and token in self.all_special_ids:
+            if skip_special_tokens and token in self.all_special_tokens:
                 continue
             if token in legacy_added_tokens:
                 if current_sub_text:
diff --git a/tests/tokenization/test_tokenization_utils.py b/tests/tokenization/test_tokenization_utils.py
index b43923df84d7..2c8f71ba9772 100644
--- a/tests/tokenization/test_tokenization_utils.py
+++ b/tests/tokenization/test_tokenization_utils.py
@@ -253,6 +253,71 @@ def test_padding_accepts_tensors(self):
         self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
         self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
 
+    @require_tokenizers
+    def test_decoding_single_token(self):
+        for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
+            with self.subTest(f"{tokenizer_class}"):
+                tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
+
+                token_id = 2300
+                decoded_flat = tokenizer.decode(token_id)
+                decoded_list = tokenizer.decode([token_id])
+
+                self.assertEqual(decoded_flat, "Force")
+                self.assertEqual(decoded_list, "Force")
+
+                token_id = 0
+                decoded_flat = tokenizer.decode(token_id)
+                decoded_list = tokenizer.decode([token_id])
+
+                self.assertEqual(decoded_flat, "[PAD]")
+                self.assertEqual(decoded_list, "[PAD]")
+
+                last_item_id = tokenizer.vocab_size - 1
+                decoded_flat = tokenizer.decode(last_item_id)
+                decoded_list = tokenizer.decode([last_item_id])
+
+                self.assertEqual(decoded_flat, "##:")
+                self.assertEqual(decoded_list, "##:")
+
+    @require_tokenizers
+    def test_decoding_skip_special_tokens(self):
+        for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
+            with self.subTest(f"{tokenizer_class}"):
+                tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
+                tokenizer.add_tokens(["ஐ"], special_tokens=True)
+
+                # test special token with other tokens, skip the special tokens
+                sentence = "This is a beautiful flower ஐ"
+                ids = tokenizer(sentence)["input_ids"]
+                decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
+                self.assertEqual(decoded_sent, "This is a beautiful flower")
+
+                # test special token with other tokens, do not skip the special tokens
+                ids = tokenizer(sentence)["input_ids"]
+                decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
+                self.assertEqual(decoded_sent, "[CLS] This is a beautiful flower ஐ [SEP]")
+
+                # test special token stand alone, skip the special tokens
+                sentence = "ஐ"
+                ids = tokenizer(sentence)["input_ids"]
+                decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
+                self.assertEqual(decoded_sent, "")
+
+                # test special token stand alone, do not skip the special tokens
+                ids = tokenizer(sentence)["input_ids"]
+                decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
+                self.assertEqual(decoded_sent, "[CLS] ஐ [SEP]")
+
+                # test single special token alone, skip
+                pad_id = 0
+                decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=True)
+                self.assertEqual(decoded_sent, "")
+
+                # test single special token alone, do not skip
+                decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=False)
+                self.assertEqual(decoded_sent, "[PAD]")
+
     @require_torch
     def test_padding_accepts_tensors_pt(self):
         import torch
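
For reference, a minimal sketch of the user-facing behavior this patch enables, built only from values that appear in the new tests; it assumes the google-bert/bert-base-cased checkpoint can be downloaded, and the expected outputs are taken from the assertions above:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

    # A bare int id now decodes the same as a one-element list; per the new inline
    # comment, the slow tokenizer previously risked splitting the single returned
    # token string in the decoding loop.
    print(tokenizer.decode(2300))    # "Force"
    print(tokenizer.decode([2300]))  # "Force"

    # Special tokens are now compared against all_special_tokens (strings), so
    # skip_special_tokens also behaves as expected for a lone special id.
    print(tokenizer.decode(0, skip_special_tokens=True))   # ""
    print(tokenizer.decode(0, skip_special_tokens=False))  # "[PAD]"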