diff --git a/src/transformers/models/esm/tokenization_esm.py b/src/transformers/models/esm/tokenization_esm.py
index 065eaae1d505..478527c0ecd1 100644
--- a/src/transformers/models/esm/tokenization_esm.py
+++ b/src/transformers/models/esm/tokenization_esm.py
@@ -14,10 +14,9 @@
 # limitations under the License.
 """Tokenization classes for ESM."""
 import os
-from typing import List, Optional, Union
+from typing import List, Optional
 
 from ...tokenization_utils import PreTrainedTokenizer
-from ...tokenization_utils_base import AddedToken
 from ...utils import logging
 
 
@@ -91,11 +90,10 @@ def _convert_token_to_id(self, token: str) -> int:
     def _tokenize(self, text, **kwargs):
         return text.split()
 
-    def get_vocab_size(self, with_added_tokens=False):
-        return len(self._id_to_token)
-
     def get_vocab(self):
-        return {token: i for i, token in enumerate(self.all_tokens)}
+        base_vocab = self._token_to_id.copy()
+        base_vocab.update(self.added_tokens_encoder)
+        return base_vocab
 
     def token_to_id(self, token: str) -> int:
         return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
@@ -156,7 +154,4 @@ def save_vocabulary(self, save_directory, filename_prefix):
 
     @property
     def vocab_size(self) -> int:
-        return self.get_vocab_size(with_added_tokens=False)
-
-    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
-        return super()._add_tokens(new_tokens, special_tokens=True)
+        return len(self.all_tokens)
diff --git a/tests/models/esm/test_tokenization_esm.py b/tests/models/esm/test_tokenization_esm.py
index 539baaf34150..aac03b535edc 100644
--- a/tests/models/esm/test_tokenization_esm.py
+++ b/tests/models/esm/test_tokenization_esm.py
@@ -87,3 +87,25 @@ def test_tokenize_special_tokens(self):
         self.assertEqual(len(token_2), 1)
         self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
         self.assertEqual(token_2[0], SPECIAL_TOKEN_2)
+
+    def test_add_tokens(self):
+        tokenizer = self.tokenizer_class(self.vocab_file)
+
+        vocab_size = len(tokenizer)
+        self.assertEqual(tokenizer.add_tokens(""), 0)
+        self.assertEqual(tokenizer.add_tokens("testoken"), 1)
+        self.assertEqual(tokenizer.add_tokens(["testoken1", "testtoken2"]), 2)
+        self.assertEqual(len(tokenizer), vocab_size + 3)
+
+        self.assertEqual(tokenizer.add_special_tokens({}), 0)
+        self.assertEqual(tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
+        self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": "<testtoken1>"})
+        self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
+        self.assertEqual(
+            tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
+        )
+        self.assertIn("<testtoken3>", tokenizer.special_tokens_map["additional_special_tokens"])
+        self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list)
+        self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2)
+
+        self.assertEqual(len(tokenizer), vocab_size + 8)
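
For context when reviewing, here is a minimal sketch (not part of the patch) of how the patched slow `EsmTokenizer` is expected to behave once this diff is applied. It assumes the public `facebook/esm2_t6_8M_UR50D` checkpoint is reachable, and the `testtokenA`/`testtokenB`/`<testtokenC>` strings are arbitrary illustrations; only the methods touched by the diff (`get_vocab`, `vocab_size`, `add_tokens`, `add_special_tokens`) are exercised.

```python
# Hedged sanity check of the patched EsmTokenizer (not part of the PR).
# Assumes the public facebook/esm2_t6_8M_UR50D checkpoint; the token strings
# "testtokenA", "testtokenB" and "<testtokenC>" are arbitrary examples.
from transformers import EsmTokenizer

tok = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

base_size = tok.vocab_size                       # patched: len(self.all_tokens)
assert tok.add_tokens(["testtokenA", "testtokenB"]) == 2

assert tok.vocab_size == base_size               # base vocab size is unchanged
assert len(tok) == base_size + 2                 # len() counts the added tokens
assert "testtokenA" in tok.get_vocab()           # get_vocab() now merges added tokens
assert tok.get_vocab()["testtokenA"] == tok.convert_tokens_to_ids("testtokenA")

# Explicitly registered special tokens show up in all_special_tokens;
# tokens added via plain add_tokens() do not.
assert tok.add_special_tokens({"additional_special_tokens": ["<testtokenC>"]}) == 1
assert "<testtokenC>" in tok.all_special_tokens
assert "testtokenA" not in tok.all_special_tokens
```

The observable change this illustrates is the same one the new `test_add_tokens` pins down: `get_vocab()` reflects tokens added at runtime and `len(tokenizer)` grows with them (`vocab_size + 3`, then `vocab_size + 8`), while `vocab_size` stays fixed to the base vocabulary.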