diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 5fbae8177..19b821fa0 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -5,11 +5,12 @@ import stat from collections import OrderedDict from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional +from dataclasses import asdict, dataclass, field import requests import numpy as np from numpy import ndarray import transformers -from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_url, cached_download +from huggingface_hub import HfApi, HfFolder, Repository, create_repo, hf_hub_url, cached_download import torch from torch import nn, Tensor, device from torch.optim import Optimizer @@ -30,6 +31,31 @@ logger = logging.getLogger(__name__) + +@dataclass +class HubTrainingArguments: + repo_name: str = field() + organization: Optional[str] = field(default=None) + private: Optional[bool] = field(default=None, metadata={"help": "Specifies the model repo's visibility"}) + local_model_path: Optional[str] = field(default=None, metadata={"help": "Local path to a model checkpoint"}) + replace_model_card: Optional[bool] = field(default=None, metadata={"help": "Controls whether the existing model card is replaced with a new one"}) + train_datasets: Optional[List[str]] = field(default=None, metadata={"help": "List of datasets used for training, included in the model card"}) + local_repo_path: Optional[str] = field(default=None, metadata={"help": "Local path to use to initialize a repo"}) + + @property + def parameters(self): + return { + "repo_name": self.repo_name, + "organization": self.organization, + "private": self.private, + "local_model_path": self.local_model_path, + "replace_model_card": self.replace_model_card, + "train_datasets": self.train_datasets, + "local_repo_path": self.local_repo_path, + "exist_ok": True, + } + + class SentenceTransformer(nn.Sequential): """ Loads or 
create a SentenceTransformer model, that can be used to map sentences / text to embeddings. @@ -44,7 +70,7 @@ def __init__(self, model_name_or_path: Optional[str] = None, modules: Optional[Iterable[nn.Module]] = None, device: Optional[str] = None, cache_folder: Optional[str] = None, - use_auth_token: Union[bool, str, None] = None + use_auth_token: Union[bool, str, None] = None, ): self._model_card_vars = {} self._model_card_text = None @@ -437,7 +463,8 @@ def save_to_hub(self, local_model_path: Optional[str] = None, exist_ok: bool = False, replace_model_card: bool = False, - train_datasets: Optional[List[str]] = None): + train_datasets: Optional[List[str]] = None, + local_repo_path: Optional[str] = "./hub"): """ Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository. @@ -449,6 +476,7 @@ def save_to_hub(self, :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible :param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card :param train_datasets: Datasets used to train the model. If set, the datasets will be added to the model card in the Hub. + :param local_repo_path: Local path where the model repo will be stored, "./hub" by default. Can be set to None to use a temporary directory. :return: The url of the commit of your model in the given repository. """ token = HfFolder.get_token() @@ -474,53 +502,50 @@ def save_to_hub(self, ) full_model_name = repo_url[len(endpoint)+1:].strip("/") - with tempfile.TemporaryDirectory() as tmp_dir: - # First create the repo (and clone its content if it's nonempty). - logger.info("Create repository and clone it if it exists") - repo = Repository(tmp_dir, clone_from=repo_url) - - # If user provides local files, copy them. - if local_model_path: - copy_tree(local_model_path, tmp_dir) - else: # Else, save model directly into local repo. 
- create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md')) - self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card, train_datasets=train_datasets) - - #Find files larger 5M and track with git-lfs - large_files = [] - for root, dirs, files in os.walk(tmp_dir): - for filename in files: - file_path = os.path.join(root, filename) - rel_path = os.path.relpath(file_path, tmp_dir) - - if os.path.getsize(file_path) > (5 * 1024 * 1024): - large_files.append(rel_path) - - if len(large_files) > 0: - logger.info("Track files with git lfs: {}".format(", ".join(large_files))) - repo.lfs_track(large_files) - - logger.info("Push model to the hub. This might take a while") - push_return = repo.push_to_hub(commit_message=commit_message) - - def on_rm_error(func, path, exc_info): - # path contains the path of the file that couldn't be removed - # let's just assume that it's read-only and unlink it. - try: - os.chmod(path, stat.S_IWRITE) - os.unlink(path) - except: - pass + repo_dir = local_repo_path + + if local_repo_path is None: + repo_dir = tempfile.mkdtemp() + + logger.info("Clone the repository") + repo = Repository(repo_dir, clone_from=repo_url) + + # If user provides local files, copy them. + if local_model_path: + copy_tree(local_model_path, repo_dir) + else: # Else, save model directly into local repo. + create_model_card = replace_model_card or not os.path.exists(os.path.join(repo_dir, 'README.md')) + self.save(repo_dir, model_name=full_model_name, create_model_card=create_model_card, + train_datasets=train_datasets) + + # Find files larger 5M and track with git-lfs + large_files = [] + for root, dirs, files in os.walk(repo_dir): + for filename in files: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, repo_dir) - # Remove .git folder. 
On Windows, the .git folder might be read-only and cannot be deleted -        # Hence, try to set write permissions on error +                if os.path.getsize(file_path) > (5 * 1024 * 1024): +                    large_files.append(rel_path) + +        if len(large_files) > 0: +            logger.info("Track files with git lfs: {}".format(", ".join(large_files))) +            repo.lfs_track(large_files) + +        logger.info("Push model to the hub. This might take a while") +        push_return = repo.push_to_hub(commit_message=commit_message) + +        def on_rm_error(func, path, exc_info): +            # path contains the path of the file that couldn't be removed +            # let's just assume that it's read-only and unlink it. try: -                for f in os.listdir(tmp_dir): -                    shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error) -            except Exception as e: -                logger.warning("Error when deleting temp folder: {}".format(str(e))) +                os.chmod(path, stat.S_IWRITE) +                os.unlink(path) +            except: pass +        if local_repo_path is None: +            shutil.rmtree(repo_dir, onerror=on_rm_error) return push_return @@ -589,7 +614,9 @@ def fit(self, show_progress_bar: bool = True, checkpoint_path: str = None, checkpoint_save_steps: int = 500, -            checkpoint_save_total_limit: int = 0 +            checkpoint_save_total_limit: int = 0, +            push_checkpoints_to_hub: bool = False, +            hub_training_arguments: HubTrainingArguments = None, ): """ Train the model with the given training objective @@ -618,8 +645,15 @@ def fit(self, :param checkpoint_path: Folder to save checkpoints during training :param checkpoint_save_steps: Will save a checkpoint after so many steps :param checkpoint_save_total_limit: Total number of checkpoints to store + :param push_checkpoints_to_hub: If True, each checkpoint will be pushed to the Hugging Face Hub + :param hub_training_arguments: Parameters used to push checkpoints to the Hugging Face Hub.
""" + if push_checkpoints_to_hub: + if checkpoint_path is None or checkpoint_save_steps is None or hub_training_arguments is None: + raise ValueError( + "You must set checkpoint_path, checkpoint_save_steps, and hub_training_arguments in order to push checkpoints to the Hugging Face Hub.") + ##Add info to model card #info_loss_functions = "\n".join(["- {} with {} training examples".format(str(loss), len(dataloader)) for dataloader, loss in train_objectives]) info_loss_functions = [] @@ -739,8 +773,7 @@ def fit(self, loss_model.train() if checkpoint_path is not None and checkpoint_save_steps is not None and checkpoint_save_steps > 0 and global_step % checkpoint_save_steps == 0: -                        self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step) - + self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step, push_checkpoints_to_hub, hub_training_arguments) self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback) @@ -748,9 +781,7 @@ self.save(output_path) if checkpoint_path is not None: - self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step) - - + self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step, push_checkpoints_to_hub, hub_training_arguments) def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None): """ @@ -782,7 +813,7 @@ if save_best_model: self.save(output_path) - def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step): + def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step, save_to_hub=False, hub_training_arguments=None): # Store new checkpoint self.save(os.path.join(checkpoint_path, str(step))) @@ -797,6 +828,11 @@ def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step): old_checkpoints = sorted(old_checkpoints, key=lambda x: x['step'])
shutil.rmtree(old_checkpoints[0]['path']) + if save_to_hub: + if hub_training_arguments is None: + raise ValueError("hub_training_arguments must be provided in order to save checkpoint to Hugging Face.") + + self.save_to_hub(**hub_training_arguments.parameters, commit_message="Push new training checkpoint.") def _load_auto_model(self, model_name_or_path): """