From 73c381c91a268ed10980c45b5a7514234aa3acc2 Mon Sep 17 00:00:00 2001 From: Wauplin Date: Mon, 28 Nov 2022 16:42:02 +0100 Subject: [PATCH] Remove hardcoded HF endpoint --- sentence_transformers/SentenceTransformer.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index e44e573a5..3c06677ae 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -443,7 +443,7 @@ def save_to_hub(self, :param repo_name: Repository name for your model in the Hub. :param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization). - :param private: Set to true, for hosting a prive model + :param private: Set to true, for hosting a private model :param commit_message: Message to commit while pushing. :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible @@ -463,18 +463,16 @@ def save_to_hub(self, else: raise ValueError("You passed and invalid repository name: {}.".format(repo_name)) - endpoint = "https://huggingface.co" repo_id = repo_name if organization: repo_id = f"{organization}/{repo_id}" - repo_url = HfApi(endpoint=endpoint).create_repo( + repo_url = HfApi().create_repo( repo_id=repo_id, token=token, private=private, repo_type=None, exist_ok=exist_ok, ) - full_model_name = repo_url[len(endpoint)+1:].strip("/") with tempfile.TemporaryDirectory() as tmp_dir: # First create the repo (and clone its content if it's nonempty). @@ -486,7 +484,7 @@ def save_to_hub(self, copy_tree(local_model_path, tmp_dir) else: # Else, save model directly into local repo. create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md')) - self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card, train_datasets=train_datasets) + self.save(tmp_dir, model_name=repo_id, create_model_card=create_model_card, train_datasets=train_datasets) #Find files larger 5M and track with git-lfs large_files = []