Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from . import __MODEL_HUB_ORGANIZATION__
from .evaluation import SentenceEvaluator
from .util import import_from_string, batch_to_device, fullname, is_sentence_transformer_model, load_dir_path, load_file_path
from .util import import_from_string, batch_to_device, fullname, is_sentence_transformer_model, load_dir_path, load_file_path, save_to_hub_args_decorator
from .models import Transformer, Pooling
from .model_card_templates import ModelCardTemplate
from . import __version__
Expand Down Expand Up @@ -471,6 +471,7 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_
with open(os.path.join(path, "README.md"), "w", encoding='utf8') as fOut:
fOut.write(model_card.strip())

@save_to_hub_args_decorator
def save_to_hub(self,
repo_id: str,
organization: Optional[str] = None,
Expand Down
20 changes: 20 additions & 0 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import requests
from torch import Tensor, device
from typing import List, Callable
Expand Down Expand Up @@ -483,3 +484,22 @@ def load_dir_path(model_name_or_path: str, directory: str, token: Optional[Union
download_kwargs["local_files_only"] = True
repo_path = snapshot_download(**download_kwargs)
return os.path.join(repo_path, directory)


def save_to_hub_args_decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# If repo_id not already set, use repo_name
repo_name = kwargs.pop("repo_name", None)
if repo_name and "repo_id" not in kwargs:
logger.warning(
"Providing a `repo_name` keyword argument to `save_to_hub` is deprecated, please use `repo_id` instead."
)
kwargs["repo_id"] = repo_name

# If positional args are used, adjust for the new "token" keyword argument
if len(args) >= 2:
args = (*args[:2], None, *args[2:])

return func(self, *args, **kwargs)
return wrapper
33 changes: 33 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,36 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs):
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["folder_path"] == "my_fake_local_model_path"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
mock_upload_folder_kwargs.clear()

# Incorrect usage: Using deprecated "repo_name" positional argument
caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub(repo_name="sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== "Providing a `repo_name` keyword argument to `save_to_hub` is deprecated, please use `repo_id` instead."
)
mock_upload_folder_kwargs.clear()

# Incorrect usage: Use positional arguments from before "token" was introduced
caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub(
"stsb-bert-tiny-safetensors", # repo_name
"sentence-transformers-testing", # organization
True, # private
"Adding new awesome Model!", # commit message
exist_ok=True,
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["commit_message"] == "Adding new awesome Model!"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)