diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index c1a9fc0c7805..5f68bf12f06a 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -190,6 +190,7 @@ def __init__(self, **kwargs):
         self.num_labels = kwargs.pop("num_labels", 2)
 
         # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
         self.prefix = kwargs.pop("prefix", None)
         self.bos_token_id = kwargs.pop("bos_token_id", None)
         self.pad_token_id = kwargs.pop("pad_token_id", None)
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 5a48ee16f061..5f155039d81c 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -14,6 +14,7 @@
 
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
 DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
+DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
 
 # Used to test Auto{Config, Model, Tokenizer} model_type detection.
diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py
index 626e576759bc..a29aecf325f3 100644
--- a/src/transformers/tokenization_auto.py
+++ b/src/transformers/tokenization_auto.py
@@ -217,6 +217,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
 
         use_fast = kwargs.pop("use_fast", False)
+
+        if config.tokenizer_class is not None:
+            if use_fast and not config.tokenizer_class.endswith("Fast"):
+                tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
+            else:
+                tokenizer_class_candidate = config.tokenizer_class
+            tokenizer_class = globals().get(tokenizer_class_candidate)
+            if tokenizer_class is None:
+                raise ValueError(f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported.")
+            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
         for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
             if isinstance(config, config_class):
                 if tokenizer_class_fast and use_fast:
diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py
index 54bfb2e13c95..524a2282492e 100644
--- a/tests/test_tokenization_auto.py
+++ b/tests/test_tokenization_auto.py
@@ -27,7 +27,13 @@
     RobertaTokenizer,
     RobertaTokenizerFast,
 )
-from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER  # noqa: F401
+from transformers.configuration_auto import AutoConfig
+from transformers.configuration_roberta import RobertaConfig
+from transformers.testing_utils import (
+    DUMMY_DIFF_TOKENIZER_IDENTIFIER,
+    DUMMY_UNKWOWN_IDENTIFIER,
+    SMALL_MODEL_IDENTIFIER,
+)
 from transformers.tokenization_auto import TOKENIZER_MAPPING
 
 
@@ -56,6 +62,14 @@ def test_tokenizer_from_model_type(self):
         self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
         self.assertEqual(tokenizer.vocab_size, 20)
 
+    def test_tokenizer_from_tokenizer_class(self):
+        config = AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER)
+        self.assertIsInstance(config, RobertaConfig)
+        # Check that tokenizer_type ≠ model_type
+        tokenizer = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=config)
+        self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
+        self.assertEqual(tokenizer.vocab_size, 12)
+
     def test_tokenizer_identifier_with_correct_config(self):
         for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
             tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
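
For context beyond the patch itself, here is a minimal usage sketch of the new lookup path. It relies on the `julien-c/dummy-diff-tokenizer` checkpoint exercised by the new test, whose config is a RoBERTa config that (judging from the test's assertions) records `"tokenizer_class": "BertTokenizer"`; everything else uses the public `AutoConfig`/`AutoTokenizer` API.

```python
# Sketch only: demonstrates the config-driven tokenizer lookup added in
# tokenization_auto.py above, using the dummy checkpoint from the tests.
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("julien-c/dummy-diff-tokenizer")
print(config.model_type)       # "roberta"
print(config.tokenizer_class)  # "BertTokenizer" (assumed to be set in config.json)

# Because config.tokenizer_class is set, AutoTokenizer returns a BertTokenizer
# instead of the RobertaTokenizer the model_type mapping would have picked.
tokenizer = AutoTokenizer.from_pretrained("julien-c/dummy-diff-tokenizer")

# With use_fast=True, a "Fast" suffix is appended to the candidate name,
# so "BertTokenizer" resolves to BertTokenizerFast.
fast_tokenizer = AutoTokenizer.from_pretrained(
    "julien-c/dummy-diff-tokenizer", use_fast=True
)
```

Resolving the candidate through `globals().get(...)` means any tokenizer class imported in `tokenization_auto.py` can be named in a config, but a typo or a not-yet-imported class only surfaces as the `ValueError` at load time.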