Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,12 @@ def from_config(cls, config, **kwargs):
else:
repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
model_class.register_for_auto_class(auto_class=cls)
cls.register(config.__class__, model_class, exist_ok=True)
# Register the remote-code class in the auto-class mapping only when the library has no local
# implementation of this model (`has_local_code` is False). If the user loads a model with
# `trust_remote_code=True` while a library model with the same name exists, registering here
# would override the auto-class mapping, and all future loads of that model would resolve to
# the remote-code class.
if not has_local_code:
cls.register(config.__class__, model_class, exist_ok=True)
model_class.register_for_auto_class(auto_class=cls)
_ = kwargs.pop("code_revision", None)
model_class = add_generation_mixin_to_remote_model(model_class)
return model_class._from_config(config, **kwargs)
Expand Down Expand Up @@ -579,8 +583,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[s
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
)
_ = hub_kwargs.pop("code_revision", None)
cls.register(config.__class__, model_class, exist_ok=True)
model_class.register_for_auto_class(auto_class=cls)
# Register the remote-code class in the auto-class mapping only when the library has no local
# implementation of this model (`has_local_code` is False). If the user loads a model with
# `trust_remote_code=True` while a library model with the same name exists, registering here
# would override the auto-class mapping, and all future loads of that model would resolve to
# the remote-code class.
if not has_local_code:
cls.register(config.__class__, model_class, exist_ok=True)
model_class.register_for_auto_class(auto_class=cls)
model_class = add_generation_mixin_to_remote_model(model_class)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
Expand Down