Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,7 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_
fOut.write(model_card.strip())

def save_to_hub(self,
repo_name: str,
organization: Optional[str] = None,
repo_id: str,
private: Optional[bool] = None,
commit_message: str = "Add new SentenceTransformer model.",
local_model_path: Optional[str] = None,
Expand All @@ -441,8 +440,7 @@ def save_to_hub(self,
"""
Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

:param repo_name: Repository name for your model in the Hub.
:param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization).
:param repo_id: A namespace (user or an organization) and a repo name separated by a /.
:param private: Set to true, for hosting a prive model
:param commit_message: Message to commit while pushing.
:param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
Expand All @@ -455,24 +453,13 @@ def save_to_hub(self,
if token is None:
raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.")

if '/' in repo_name:
splits = repo_name.split('/', maxsplit=1)
if organization is None or organization == splits[0]:
organization = splits[0]
repo_name = splits[1]
else:
raise ValueError("You passed and invalid repository name: {}.".format(repo_name))

endpoint = "https://huggingface.co"
repo_id = repo_name
if organization:
repo_id = f"{organization}/{repo_id}"
repo_url = HfApi(endpoint=endpoint).create_repo(
repo_id=repo_id,
token=token,

repo_url = HfApi(endpoint=endpoint, token=token).create_repo(
repo_id=repo_id, # repo_id : A namespace (user or an organization) and a repo name separated by a /. [https://huggingface.co/docs/huggingface_hub/v0.13.2/en/package_reference/hf_api#huggingface_hub.HfApi.create_repo]
private=private,
repo_type=None,
exist_ok=exist_ok,
exist_ok=exist_ok
)
full_model_name = repo_url[len(endpoint)+1:].strip("/")

Expand Down