def save_to_hub_args_decorator(func):
    """Wrap ``save_to_hub`` so calls written for its older signature keep working.

    Two legacy calling conventions are translated before ``func`` is invoked:

    * A ``repo_name`` keyword argument is forwarded as ``repo_id`` (with a
      deprecation warning) when the caller has not supplied a ``repo_id``
      either positionally or by keyword. If a ``repo_id`` was supplied,
      ``repo_name`` is dropped.
    * When two or more positional arguments are given, a ``None`` placeholder
      is inserted after the second one, so the remaining positionals line up
      with the parameters that follow the newer "token" argument.

    Args:
        func: The method to wrap. It is called as
            ``func(self, *args, **kwargs)`` with the adjusted arguments.

    Returns:
        The wrapping function, carrying ``func``'s metadata via
        ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Always strip the deprecated keyword so it never reaches `func`.
        repo_name = kwargs.pop("repo_name", None)

        # `repo_id` is the first parameter after `self`, so a positional
        # argument fills it too. Injecting the keyword on top of that would
        # raise "got multiple values for argument 'repo_id'".
        repo_id_given = bool(args) or "repo_id" in kwargs
        if repo_name and not repo_id_given:
            logger.warning(
                "Providing a `repo_name` keyword argument to `save_to_hub` is deprecated, please use `repo_id` instead."
            )
            kwargs["repo_id"] = repo_name

        # If positional args are used, adjust for the new "token" keyword
        # argument: pad index 2 so older positional calls keep their meaning.
        if len(args) >= 2:
            args = (*args[:2], None, *args[2:])

        return func(self, *args, **kwargs)

    return wrapper
+ ) + mock_upload_folder_kwargs.clear() + + # Incorrect usage: Use positional arguments from before "token" was introduced + caplog.clear() + with caplog.at_level(logging.WARNING): + url = model.save_to_hub( + "stsb-bert-tiny-safetensors", # repo_name + "sentence-transformers-testing", # organization + True, # private + "Adding new awesome Model!", # commit message + exist_ok=True, + ) + assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert mock_upload_folder_kwargs["commit_message"] == "Adding new awesome Model!" + assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456" + assert len(caplog.record_tuples) == 1 + assert ( + caplog.record_tuples[0][2] + == 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.' + )