Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 89 additions & 53 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import stat
from collections import OrderedDict
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional
from dataclasses import asdict, dataclass, field
import requests
import numpy as np
from numpy import ndarray
import transformers
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_url, cached_download
from huggingface_hub import HfApi, HfFolder, Repository, create_repo, hf_hub_url, cached_download
import torch
from torch import nn, Tensor, device
from torch.optim import Optimizer
Expand All @@ -30,6 +31,31 @@

logger = logging.getLogger(__name__)


@dataclass
class HubTrainingArguments:
    """Arguments controlling how training checkpoints are pushed to the Hugging Face Hub.

    An instance of this class is passed to ``SentenceTransformer.fit`` (as
    ``hub_training_arguments``); its :attr:`parameters` are forwarded as keyword
    arguments to ``SentenceTransformer.save_to_hub`` on every checkpoint push.
    """

    # Name of the repository on the Hub (required).
    repo_name: str
    organization: Optional[str] = field(default=None, metadata={"help": "Organization in which to push the model (if different from the user's namespace)"})
    private: Optional[bool] = field(default=None, metadata={"help": "Specifies the model repo's visibility"})
    local_model_path: Optional[str] = field(default=None, metadata={"help": "Local path to a model checkpoint"})
    replace_model_card: Optional[bool] = field(default=None, metadata={"help": "Controls whether the existing model card is replaced with a new one"})
    train_datasets: Optional[List[str]] = field(default=None, metadata={"help": "List of datasets used for training, included in the model card"})
    local_repo_path: Optional[str] = field(default=None, metadata={"help": "Local path to use to initialize a repo"})

    @property
    def parameters(self):
        """Return the keyword arguments for ``save_to_hub``.

        Uses :func:`dataclasses.asdict` so the mapping automatically stays in
        sync with the declared fields (the previous hand-written dict had to be
        updated manually whenever a field changed). ``exist_ok`` is forced to
        ``True`` because checkpoints are pushed repeatedly to the same repo.
        """
        return {**asdict(self), "exist_ok": True}


class SentenceTransformer(nn.Sequential):
"""
Loads or create a SentenceTransformer model, that can be used to map sentences / text to embeddings.
Expand All @@ -44,7 +70,7 @@ def __init__(self, model_name_or_path: Optional[str] = None,
modules: Optional[Iterable[nn.Module]] = None,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
use_auth_token: Union[bool, str, None] = None
use_auth_token: Union[bool, str, None] = None,
):
self._model_card_vars = {}
self._model_card_text = None
Expand Down Expand Up @@ -437,7 +463,8 @@ def save_to_hub(self,
local_model_path: Optional[str] = None,
exist_ok: bool = False,
replace_model_card: bool = False,
train_datasets: Optional[List[str]] = None):
train_datasets: Optional[List[str]] = None,
local_repo_path: Optional[str] = "./hub"):
"""
Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

Expand All @@ -449,6 +476,7 @@ def save_to_hub(self,
:param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
:param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card
:param train_datasets: Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
:param local_repo_path: Local path where the model repo will be stored, "./hub" by default. Can be set to None to use a temporary directory.
:return: The url of the commit of your model in the given repository.
"""
token = HfFolder.get_token()
Expand All @@ -474,53 +502,50 @@ def save_to_hub(self,
)
full_model_name = repo_url[len(endpoint)+1:].strip("/")

with tempfile.TemporaryDirectory() as tmp_dir:
# First create the repo (and clone its content if it's nonempty).
logger.info("Create repository and clone it if it exists")
repo = Repository(tmp_dir, clone_from=repo_url)

# If user provides local files, copy them.
if local_model_path:
copy_tree(local_model_path, tmp_dir)
else: # Else, save model directly into local repo.
create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md'))
self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card, train_datasets=train_datasets)

#Find files larger 5M and track with git-lfs
large_files = []
for root, dirs, files in os.walk(tmp_dir):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, tmp_dir)

if os.path.getsize(file_path) > (5 * 1024 * 1024):
large_files.append(rel_path)

if len(large_files) > 0:
logger.info("Track files with git lfs: {}".format(", ".join(large_files)))
repo.lfs_track(large_files)

logger.info("Push model to the hub. This might take a while")
push_return = repo.push_to_hub(commit_message=commit_message)

def on_rm_error(func, path, exc_info):
# path contains the path of the file that couldn't be removed
# let's just assume that it's read-only and unlink it.
try:
os.chmod(path, stat.S_IWRITE)
os.unlink(path)
except:
pass
repo_dir = local_repo_path

if local_repo_path is None:
repo_dir = tempfile.mkdtemp()

logger.info("Clone the repository")
repo = Repository(repo_dir, clone_from=repo_url)

# If user provides local files, copy them.
if local_model_path:
copy_tree(local_model_path, repo_dir)
else: # Else, save model directly into local repo.
create_model_card = replace_model_card or not os.path.exists(os.path.join(repo_dir, 'README.md'))
self.save(repo_dir, model_name=full_model_name, create_model_card=create_model_card,
train_datasets=train_datasets)

# Find files larger 5M and track with git-lfs
large_files = []
for root, dirs, files in os.walk(repo_dir):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, repo_dir)

# Remove .git folder. On Windows, the .git folder might be read-only and cannot be deleted
# Hence, try to set write permissions on error
if os.path.getsize(file_path) > (5 * 1024 * 1024):
large_files.append(rel_path)

if len(large_files) > 0:
logger.info("Track files with git lfs: {}".format(", ".join(large_files)))
repo.lfs_track(large_files)

logger.info("Push model to the hub. This might take a while")
push_return = repo.push_to_hub(commit_message=commit_message)

def on_rm_error(func, path, exc_info):
# path contains the path of the file that couldn't be removed
# let's just assume that it's read-only and unlink it.
try:
for f in os.listdir(tmp_dir):
shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error)
except Exception as e:
logger.warning("Error when deleting temp folder: {}".format(str(e)))
os.chmod(path, stat.S_IWRITE)
os.unlink(path)
except:
pass

if local_repo_path is None:
shutil.rmtree(repo_dir)

return push_return

Expand Down Expand Up @@ -589,7 +614,9 @@ def fit(self,
show_progress_bar: bool = True,
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
checkpoint_save_total_limit: int = 0
checkpoint_save_total_limit: int = 0,
push_checkpoints_to_hub: bool = False,
hub_training_arguments: HubTrainingArguments = None,
):
"""
Train the model with the given training objective
Expand Down Expand Up @@ -618,8 +645,15 @@ def fit(self,
:param checkpoint_path: Folder to save checkpoints during training
:param checkpoint_save_steps: Will save a checkpoint after so many steps
:param checkpoint_save_total_limit: Total number of checkpoints to store
:param push_checkpoints_to_hub: If True, each checkpoint will be pushed to the Hugging Face Hub
:param hub_training_arguments: Parameters used to push checkpoints to the Hugging Face Hub.
"""

if push_checkpoints_to_hub:
if checkpoint_path is None or checkpoint_save_steps is None or hub_training_arguments is None:
raise ValueError(
"You must set checkpoint_path checkpoint_save_steps, and hub_training_arguments in order to push checkpoints to the Hugging Face Hub.")

##Add info to model card
#info_loss_functions = "\n".join(["- {} with {} training examples".format(str(loss), len(dataloader)) for dataloader, loss in train_objectives])
info_loss_functions = []
Expand Down Expand Up @@ -739,18 +773,15 @@ def fit(self,
loss_model.train()

if checkpoint_path is not None and checkpoint_save_steps is not None and checkpoint_save_steps > 0 and global_step % checkpoint_save_steps == 0:
self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)

self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step, push_checkpoints_to_hub, hub_training_arguments)

self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)

if evaluator is None and output_path is not None: #No evaluator, but output path: save final model version
self.save(output_path)

if checkpoint_path is not None:
self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)


self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step, push_checkpoints_to_hub, hub_training_arguments)

def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None):
"""
Expand Down Expand Up @@ -782,7 +813,7 @@ def _eval_during_training(self, evaluator, output_path, save_best_model, epoch,
if save_best_model:
self.save(output_path)

def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step, save_to_hub=False, hub_training_arguments=None):
# Store new checkpoint
self.save(os.path.join(checkpoint_path, str(step)))

Expand All @@ -797,6 +828,11 @@ def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
old_checkpoints = sorted(old_checkpoints, key=lambda x: x['step'])
shutil.rmtree(old_checkpoints[0]['path'])

if save_to_hub:
if hub_training_arguments is None:
raise ValueError("hub_training_arguments must be provided in order to save checkpoint to Hugging Face.")

self.save_to_hub(**hub_training_arguments.parameters, commit_message="Push new training checkpoint.")

def _load_auto_model(self, model_name_or_path):
"""
Expand Down