From 72f1e73b2d129b59483ab72e37cce66f8ff58b8d Mon Sep 17 00:00:00 2001 From: Pushpdeep Singh <93198956+iampushpdeep@users.noreply.github.com> Date: Sat, 18 Mar 2023 17:04:10 +0530 Subject: [PATCH] update SentenceTransformer.py Fixed to issue[https://github.com/UKPLab/sentence-transformers/issues/1760] of uploading the model to HuggingFace Hub --- sentence_transformers/SentenceTransformer.py | 25 +++++--------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index e44e573a5..ea666831c 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -430,8 +430,7 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_ fOut.write(model_card.strip()) def save_to_hub(self, - repo_name: str, - organization: Optional[str] = None, + repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new SentenceTransformer model.", local_model_path: Optional[str] = None, @@ -441,8 +440,7 @@ def save_to_hub(self, """ Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository. - :param repo_name: Repository name for your model in the Hub. - :param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization). + :param repo_id: A namespace (user or an organization) and a repo name separated by a /. :param private: Set to true, for hosting a prive model :param commit_message: Message to commit while pushing. :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded @@ -455,24 +453,13 @@ def save_to_hub(self, if token is None: raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.") - if '/' in repo_name: - splits = repo_name.split('/', maxsplit=1) - if organization is None or organization == splits[0]: - organization = splits[0] - repo_name = splits[1] - else: - raise ValueError("You passed and invalid repository name: {}.".format(repo_name)) - endpoint = "https://huggingface.co" - repo_id = repo_name - if organization: - repo_id = f"{organization}/{repo_id}" - repo_url = HfApi(endpoint=endpoint).create_repo( - repo_id=repo_id, - token=token, + + repo_url = HfApi(endpoint=endpoint, token=token).create_repo( + repo_id=repo_id, # repo_id : A namespace (user or an organization) and a repo name separated by a /. [https://huggingface.co/docs/huggingface_hub/v0.13.2/en/package_reference/hf_api#huggingface_hub.HfApi.create_repo] private=private, repo_type=None, - exist_ok=exist_ok, + exist_ok=exist_ok ) full_model_name = repo_url[len(endpoint)+1:].strip("/")