Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ class SentenceTransformer(nn.Sequential):
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
:param cache_folder: Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.
:param use_auth_token: HuggingFace authentication token to download private models.
:param model_revision: The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on huggingface.co.
"""
def __init__(self, model_name_or_path: Optional[str] = None,
modules: Optional[Iterable[nn.Module]] = None,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
use_auth_token: Union[bool, str, None] = None
use_auth_token: Union[bool, str, None] = None,
model_revision: Optional[str] = None
):
self._model_card_vars = {}
self._model_card_text = None
Expand Down Expand Up @@ -89,7 +91,8 @@ def __init__(self, model_name_or_path: Optional[str] = None,
library_name='sentence-transformers',
library_version=__version__,
ignore_files=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
use_auth_token=use_auth_token)
use_auth_token=use_auth_token,
revision=model_revision)

if os.path.exists(os.path.join(model_path, 'modules.json')): #Load as SentenceTransformer model
modules = self._load_sbert_model(model_path)
Expand Down