diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 5912678e12d2..2f93efb1f57e 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -208,6 +208,9 @@ def from_config(cls, config, **kwargs): trust_remote_code = kwargs.pop("trust_remote_code", None) has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map has_local_code = type(config) in cls._model_mapping + explicit_local_code = has_local_code and not _get_model_class( + config, cls._model_mapping + ).__module__.startswith("transformers.") if has_remote_code: class_ref = config.auto_map[cls.__name__] if "--" in class_ref: @@ -218,7 +221,7 @@ def from_config(cls, config, **kwargs): trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: if "--" in class_ref: repo_id, class_ref = class_ref.split("--") else: @@ -233,7 +236,7 @@ def from_config(cls, config, **kwargs): _ = kwargs.pop("code_revision", None) model_class = add_generation_mixin_to_remote_model(model_class) return model_class._from_config(config, **kwargs) - elif type(config) in cls._model_mapping: + elif has_local_code: model_class = _get_model_class(config, cls._model_mapping) return model_class._from_config(config, **kwargs) @@ -342,6 +345,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map has_local_code = type(config) in cls._model_mapping + explicit_local_code = has_local_code and not _get_model_class( + config, cls._model_mapping + ).__module__.startswith("transformers.") upstream_repo = None if has_remote_code: class_ref = config.auto_map[cls.__name__] @@ -359,7 +365,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], # Set the adapter kwargs kwargs["adapter_kwargs"] = adapter_kwargs - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: model_class = get_class_from_dynamic_module( class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs ) @@ -374,7 +380,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) - elif type(config) in cls._model_mapping: + elif has_local_code: model_class = _get_model_class(config, cls._model_mapping) if model_class.config_class == config.sub_configs.get("text_config", None): config = config.get_text_config() diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 78d66d52d8d6..b543a963f010 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -1470,6 +1470,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], config_dict, unused_kwargs = PreTrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"] has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING + explicit_local_code = has_local_code and not CONFIG_MAPPING[config_dict["model_type"]].__module__.startswith( + "transformers." + ) if has_remote_code: class_ref = config_dict["auto_map"]["AutoConfig"] if "--" in class_ref: @@ -1480,7 +1483,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: config_class = get_class_from_dynamic_module( class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs ) diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 475328188bc9..69745c5847be 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -337,6 +337,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): has_remote_code = feature_extractor_auto_map is not None has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING + explicit_local_code = has_local_code and not ( + feature_extractor_class or FEATURE_EXTRACTOR_MAPPING[type(config)] + ).__module__.startswith("transformers.") if has_remote_code: if "--" in feature_extractor_auto_map: upstream_repo = feature_extractor_auto_map.split("--")[0] @@ -346,7 +349,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: feature_extractor_class = get_class_from_dynamic_module( feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs ) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 1baa1fb64813..2a9e2720106e 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -724,6 +724,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): # Handle remote code has_remote_code = image_processor_auto_map is not None has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING + explicit_local_code = has_local_code and not ( + image_processor_class or IMAGE_PROCESSOR_MAPPING[type(config)] + ).__module__.startswith("transformers.") if has_remote_code: class_ref = _resolve_auto_map_class_ref(image_processor_auto_map, backend) upstream_repo = class_ref.split("--")[0] if "--" in class_ref else None @@ -731,7 +734,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) image_processor_class.register_for_auto_class() diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c3a1a1745762..cb4c9c61d601 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -403,6 +403,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): has_remote_code = processor_auto_map is not None has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING + explicit_local_code = has_local_code and not ( + processor_class or PROCESSOR_MAPPING[type(config)] + ).__module__.startswith("transformers.") if has_remote_code: if "--" in processor_auto_map: upstream_repo = processor_auto_map.split("--")[0] @@ -412,7 +415,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: processor_class = get_class_from_dynamic_module( processor_auto_map, pretrained_model_name_or_path, **kwargs ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 8a199d4487fc..0e50f2a0041d 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -738,7 +738,13 @@ def from_pretrained( or tokenizer_class_from_name(tokenizer_config_class + "Fast") is not None ) ) - + explicit_local_code = has_local_code and ( + tokenizer_config_class is not None + and not ( + tokenizer_class_from_name(tokenizer_config_class).__module__.startswith("transformers.") + and tokenizer_class_from_name(tokenizer_config_class + "Fast").__module__.startswith("transformers.") + ) + ) # V5: Skip remote tokenizer for custom models with incorrect hub tokenizer class if has_remote_code and config_model_type in MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS: has_remote_code = False @@ -758,7 +764,7 @@ def from_pretrained( trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: # BC v5: register *Fast aliases before remote code loads. if tokenizer_config_class: tokenizer_class_from_name(tokenizer_config_class.removesuffix("Fast")) diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py index ad062387fef6..3b7d36a9aeb2 100644 --- a/src/transformers/models/auto/video_processing_auto.py +++ b/src/transformers/models/auto/video_processing_auto.py @@ -364,6 +364,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): has_remote_code = video_processor_auto_map is not None has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING + explicit_local_code = has_local_code and not ( + video_processor_class or VIDEO_PROCESSOR_MAPPING[type(config)] + ).__module__.startswith("transformers.") if has_remote_code: if "--" in video_processor_auto_map: upstream_repo = video_processor_auto_map.split("--")[0] @@ -373,7 +376,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo ) - if has_remote_code and trust_remote_code: + if has_remote_code and trust_remote_code and not explicit_local_code: class_ref = video_processor_auto_map video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) diff --git a/tests/models/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py index c44ed9ef2733..2ccbde235078 100644 --- a/tests/models/auto/test_configuration_auto.py +++ b/tests/models/auto/test_configuration_auto.py @@ -134,7 +134,12 @@ def __init__(self, **kwargs): config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=False) self.assertEqual(config.__class__.__name__, "NewModelConfigLocal") - # If remote is enabled, we load from the Hub + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) + self.assertEqual(config.__class__.__name__, "NewModelConfigLocal") + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewModelConfigLocal.__module__ = "transformers.models.new_model.configuration_new_model" config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) self.assertEqual(config.__class__.__name__, "NewModelConfig") diff --git a/tests/models/auto/test_feature_extraction_auto.py b/tests/models/auto/test_feature_extraction_auto.py index 7ecdf1756bb4..0d9d5546ec44 100644 --- a/tests/models/auto/test_feature_extraction_auto.py +++ b/tests/models/auto/test_feature_extraction_auto.py @@ -174,7 +174,15 @@ class NewFeatureExtractor(Wav2Vec2FeatureExtractor): self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor") self.assertTrue(feature_extractor.is_local) - # If remote is enabled, we load from the Hub + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + feature_extractor = AutoFeatureExtractor.from_pretrained( + "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True + ) + self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor") + self.assertTrue(feature_extractor.is_local) + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewFeatureExtractor.__module__ = "transformers.models.custom.configuration_custom" feature_extractor = AutoFeatureExtractor.from_pretrained( "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True ) diff --git a/tests/models/auto/test_image_processing_auto.py b/tests/models/auto/test_image_processing_auto.py index 583836c2b099..886292830678 100644 --- a/tests/models/auto/test_image_processing_auto.py +++ b/tests/models/auto/test_image_processing_auto.py @@ -261,7 +261,15 @@ class NewImageProcessor(CLIPImageProcessor): self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor") self.assertTrue(image_processor.is_local) - # If remote is enabled, we load from the Hub + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + image_processor = AutoImageProcessor.from_pretrained( + "hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True + ) + self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor") + self.assertTrue(image_processor.is_local) + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewImageProcessor.__module__ = "transformers.models.custom.configuration_custom" image_processor = AutoImageProcessor.from_pretrained( "hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True ) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 14a7862a68c1..5c0621d3cafb 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -459,7 +459,13 @@ class NewModel(BertModel): model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=False) self.assertEqual(model.config.__class__.__name__, "NewModelConfigLocal") - # If remote is enabled, we load from the Hub + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) + self.assertEqual(model.config.__class__.__name__, "NewModelConfigLocal") + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewModelConfigLocal.__module__ = "transformers.models.new_model.configuration_new_model" + NewModel.__module__ = "transformers.models.new_model.modeling_new_model" model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) self.assertEqual(model.config.__class__.__name__, "NewModelConfig") diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py index 50f73d8851a4..51a9084d52be 100644 --- a/tests/models/auto/test_processor_auto.py +++ b/tests/models/auto/test_processor_auto.py @@ -324,7 +324,19 @@ def __init__(self, feature_extractor, tokenizer): self.assertFalse(processor.feature_extractor.special_attribute_present) self.assertFalse(processor.tokenizer.special_attribute_present) - # If remote is enabled, we load from the Hub. + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + processor = AutoProcessor.from_pretrained( + "hf-internal-testing/test_dynamic_processor_updated", trust_remote_code=True + ) + self.assertEqual(processor.__class__.__name__, "NewProcessor") + self.assertFalse(processor.special_attribute_present) + self.assertFalse(processor.feature_extractor.special_attribute_present) + self.assertFalse(processor.tokenizer.special_attribute_present) + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewFeatureExtractor.__module__ = "transformers.models.custom.feature_extraction_custom" + NewTokenizer.__module__ = "transformers.models.custom.tokenization_custom" + NewProcessor.__module__ = "transformers.models.custom.configuration_custom" processor = AutoProcessor.from_pretrained( "hf-internal-testing/test_dynamic_processor_updated", trust_remote_code=True ) diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 12ffcb0b4c21..2bc79a3f82d6 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -503,6 +503,15 @@ class NewTokenizer(BertTokenizer): self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") self.assertFalse(tokenizer.special_attribute_present) + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, use_fast=False + ) + self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") + self.assertFalse(tokenizer.special_attribute_present) + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewTokenizer.__module__ = "transformers.models.custom.configuration_custom" tokenizer = AutoTokenizer.from_pretrained( "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, use_fast=False ) diff --git a/tests/models/auto/test_video_processing_auto.py b/tests/models/auto/test_video_processing_auto.py index c58345027e31..baccddbdc652 100644 --- a/tests/models/auto/test_video_processing_auto.py +++ b/tests/models/auto/test_video_processing_auto.py @@ -226,7 +226,15 @@ class NewVideoProcessor(LlavaOnevisionVideoProcessor): self.assertEqual(video_processor.__class__.__name__, "NewVideoProcessor") self.assertTrue(video_processor.is_local) - # If remote is enabled, we load from the Hub + # If remote code is enabled but the user explicitly registered the local one, we load the local one. + video_processor = AutoVideoProcessor.from_pretrained( + "hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True + ) + self.assertEqual(video_processor.__class__.__name__, "NewVideoProcessor") + self.assertTrue(video_processor.is_local) + + # If remote code is enabled but local code originated from transformers, we load the remote one. + NewVideoProcessor.__module__ = "transformers.models.custom.configuration_custom" video_processor = AutoVideoProcessor.from_pretrained( "hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True )