diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 7b1365a3e3c5..68b846da965b 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -60,7 +60,10 @@ def processor_class_from_name(class_name: str):
             module = importlib.import_module(f".{module_name}", "transformers.models")
             return getattr(module, class_name)
-            break
+
+    for processor in PROCESSOR_MAPPING._extra_content.values():
+        if getattr(processor, "__name__", None) == class_name:
+            return processor

     return None
@@ -231,3 +234,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
             f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
             f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}"
         )
+
+    @staticmethod
+    def register(config_class, processor_class):
+        """
+        Register a new processor for this class.
+
+        Args:
+            config_class ([`PretrainedConfig`]):
+                The configuration corresponding to the model to register.
+            processor_class ([`ProcessorMixin`]): The processor to register.
+        """
+        PROCESSOR_MAPPING.register(config_class, processor_class)
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 83c85b1a203d..dad3d5c7d613 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -21,9 +21,13 @@
 from pathlib import Path

 from .dynamic_module_utils import custom_object_save
+from .file_utils import PushToHubMixin, copy_func
 from .tokenization_utils_base import PreTrainedTokenizerBase
+from .utils import logging


+logger = logging.get_logger(__name__)
+
 # Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
 spec = importlib.util.spec_from_file_location(
     "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
@@ -37,7 +41,7 @@
 }


-class ProcessorMixin:
+class ProcessorMixin(PushToHubMixin):
     """
     This is a mixin used to provide saving/loading functionality for all processor classes.
     """
@@ -88,7 +92,7 @@ def __repr__(self):
         attributes_repr = "\n".join(attributes_repr)
         return f"{self.__class__.__name__}:\n{attributes_repr}"

-    def save_pretrained(self, save_directory):
+    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         """
         Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
         can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
@@ -105,7 +109,24 @@ def save_pretrained(self, save_directory):
             save_directory (`str` or `os.PathLike`):
                 Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                 be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your processor to the Hugging Face model hub after saving it.
+
+                <Tip warning={true}>
+
+                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
+                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
+                folder. Pass along `temp_dir=True` to use a temporary directory instead.
+
+                </Tip>
+
+            kwargs:
+                Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
         """
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
@@ -129,6 +150,10 @@ def save_pretrained(self, save_directory):
             if isinstance(attribute, PreTrainedTokenizerBase):
                 del attribute.init_kwargs["auto_map"]

+        if push_to_hub:
+            url = self._push_to_hub(repo, commit_message=commit_message)
+            logger.info(f"Processor pushed to the hub in this commit: {url}")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
@@ -205,3 +230,9 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
             args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

         return args
+
+
+ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
+ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
+    object="processor", object_class="AutoProcessor", object_files="processor files"
+)
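Note on the two module-level lines above: formatting the inherited docstring in place would mutate `PushToHubMixin.push_to_hub.__doc__` for every class that shares it, so the method is copied first and only the copy's docstring is templated. A minimal sketch of what `copy_func` does; the real helper lives in `transformers.file_utils`, so this stand-alone version is an approximation for illustration only:

```python
import functools
import types


def copy_func(f):
    """Return a copy of f whose __doc__ can be edited without touching the original."""
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def shared():
    """Upload the {object} files."""


mine = copy_func(shared)
mine.__doc__ = mine.__doc__.format(object="processor")
assert shared.__doc__ == "Upload the {object} files."  # the shared docstring is untouched
assert mine.__doc__ == "Upload the processor files."
```

The tests below exercise the new registration path end to end. As a quick reference, here is a minimal sketch of the workflow `AutoProcessor.register` enables; it assumes the repository's `tests/utils` directory is importable (the custom classes are the same fixtures the tests use), and the load path is a placeholder:

```python
import sys
from pathlib import Path

# Assumption: run from the repository root so tests/utils is importable.
sys.path.append(str(Path("tests") / "utils"))

from test_module.custom_configuration import CustomConfig
from test_module.custom_feature_extraction import CustomFeatureExtractor
from test_module.custom_processing import CustomProcessor
from test_module.custom_tokenization import CustomTokenizer

from transformers import AutoConfig, AutoFeatureExtractor, AutoProcessor, AutoTokenizer

# Hook the custom model type and its processing classes into the auto-API.
AutoConfig.register("custom", CustomConfig)
AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor)
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
AutoProcessor.register(CustomConfig, CustomProcessor)

# A directory saved with CustomProcessor.save_pretrained(...) now round-trips
# through AutoProcessor like any built-in processor.
processor = AutoProcessor.from_pretrained("path/to/saved/custom-processor")  # placeholder path
assert type(processor).__name__ == "CustomProcessor"
```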
""" + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo = self._create_or_get_repo(save_directory, **kwargs) + os.makedirs(save_directory, exist_ok=True) # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # loaded from the Hub. @@ -129,6 +150,10 @@ def save_pretrained(self, save_directory): if isinstance(attribute, PreTrainedTokenizerBase): del attribute.init_kwargs["auto_map"] + if push_to_hub: + url = self._push_to_hub(repo, commit_message=commit_message) + logger.info(f"Processor pushed to the hub in this commit: {url}") + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): r""" @@ -205,3 +230,9 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs) args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) return args + + +ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) +ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( + object="processor", object_class="AutoProcessor", object_files="processor files" +) diff --git a/tests/test_processor_auto.py b/tests/test_processor_auto.py index 7cbb5b06a9dd..d4a543ee5cbb 100644 --- a/tests/test_processor_auto.py +++ b/tests/test_processor_auto.py @@ -23,7 +23,19 @@ from huggingface_hub import Repository, delete_repo, login from requests.exceptions import HTTPError -from transformers import AutoProcessor, AutoTokenizer, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor +from transformers import ( + CONFIG_MAPPING, + FEATURE_EXTRACTOR_MAPPING, + PROCESSOR_MAPPING, + TOKENIZER_MAPPING, + AutoConfig, + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2Processor, +) from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available from transformers.testing_utils import PASS, USER, is_staging_test from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE @@ -31,6 +43,7 @@ sys.path.append(str(Path(__file__).parent.parent / "utils")) +from test_module.custom_configuration import CustomConfig # noqa E402 from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402 from test_module.custom_processing import CustomProcessor # noqa E402 from test_module.custom_tokenization import CustomTokenizer # noqa E402 @@ -41,10 +54,12 @@ ) SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json") -SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") +SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") class AutoFeatureExtractorTest(unittest.TestCase): + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] + def test_processor_from_model_shortcut(self): processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsInstance(processor, Wav2Vec2Processor) @@ -154,6 +169,42 @@ def test_from_pretrained_dynamic_processor(self): else: self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") + def test_new_processor_registration(self): + try: + AutoConfig.register("custom", CustomConfig) + AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor) + AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer) + AutoProcessor.register(CustomConfig, CustomProcessor) + # Trying to register something existing in the Transformers library 
will raise an error + with self.assertRaises(ValueError): + AutoProcessor.register(Wav2Vec2Config, Wav2Vec2Processor) + + # Now that the config is registered, it can be used as any other config with the auto-API + feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) + + with tempfile.TemporaryDirectory() as tmp_dir: + vocab_file = os.path.join(tmp_dir, "vocab.txt") + with open(vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) + tokenizer = CustomTokenizer(vocab_file) + + processor = CustomProcessor(feature_extractor, tokenizer) + + with tempfile.TemporaryDirectory() as tmp_dir: + processor.save_pretrained(tmp_dir) + new_processor = AutoProcessor.from_pretrained(tmp_dir) + self.assertIsInstance(new_processor, CustomProcessor) + + finally: + if "custom" in CONFIG_MAPPING._extra_content: + del CONFIG_MAPPING._extra_content["custom"] + if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content: + del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig] + if CustomConfig in TOKENIZER_MAPPING._extra_content: + del TOKENIZER_MAPPING._extra_content[CustomConfig] + if CustomConfig in PROCESSOR_MAPPING._extra_content: + del PROCESSOR_MAPPING._extra_content[CustomConfig] + @is_staging_test class ProcessorPushToHubTester(unittest.TestCase): @@ -165,17 +216,55 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): + try: + delete_repo(token=cls._token, name="test-processor") + except HTTPError: + pass + + try: + delete_repo(token=cls._token, name="test-processor-org", organization="valid_org") + except HTTPError: + pass + try: delete_repo(token=cls._token, name="test-dynamic-processor") except HTTPError: pass + def test_push_to_hub(self): + processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) + with tempfile.TemporaryDirectory() as tmp_dir: + processor.save_pretrained( + os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token + ) + + new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor") + for k, v in processor.feature_extractor.__dict__.items(): + self.assertEqual(v, getattr(new_processor.feature_extractor, k)) + self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) + + def test_push_to_hub_in_organization(self): + processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) + + with tempfile.TemporaryDirectory() as tmp_dir: + processor.save_pretrained( + os.path.join(tmp_dir, "test-processor-org"), + push_to_hub=True, + use_auth_token=self._token, + organization="valid_org", + ) + + new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org") + for k, v in processor.feature_extractor.__dict__.items(): + self.assertEqual(v, getattr(new_processor.feature_extractor, k)) + self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) + def test_push_to_hub_dynamic_processor(self): CustomFeatureExtractor.register_for_auto_class() CustomTokenizer.register_for_auto_class() CustomProcessor.register_for_auto_class() - feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) + feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) with tempfile.TemporaryDirectory() as tmp_dir: vocab_file = os.path.join(tmp_dir, "vocab.txt")
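For completeness, a minimal sketch of the user-facing flow these tests cover, mirroring `test_push_to_hub`; the directory name is a placeholder and valid Hub credentials (via `use_auth_token` or a prior login) are assumed:

```python
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# save_pretrained now saves locally, then forwards hub kwargs (commit_message,
# use_auth_token, organization, temp_dir, ...) to PushToHubMixin.push_to_hub.
processor.save_pretrained(
    "test-processor",  # a local clone of the target repo, or a fresh directory
    push_to_hub=True,
    commit_message="Upload processor",
)

# The processor can then be reloaded straight from the Hub:
# reloaded = Wav2Vec2Processor.from_pretrained("<user>/test-processor")
```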