Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class SentenceTransformer(nn.Sequential):
:param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU
can be used.
:param cache_folder: Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
:param revision: The specific model version to use. It can be a branch name, a tag name, or a commit id,
for a stored model on Hugging Face.
:param trust_remote_code: Whether or not to allow for custom models defined on the Hub in their own modeling files.
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
Expand All @@ -83,6 +85,7 @@ def __init__(
device: Optional[str] = None,
cache_folder: Optional[str] = None,
trust_remote_code: bool = False,
revision: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
Expand Down Expand Up @@ -187,13 +190,21 @@ def __init__(
# A model from sentence-transformers
model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path

if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder):
if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder, revision=revision):
modules = self._load_sbert_model(
model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code
model_name_or_path,
token=token,
cache_folder=cache_folder,
revision=revision,
trust_remote_code=trust_remote_code,
)
else:
modules = self._load_auto_model(
model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code
model_name_or_path,
token=token,
cache_folder=cache_folder,
revision=revision,
trust_remote_code=trust_remote_code,
)

if modules is not None and not isinstance(modules, OrderedDict):
Expand Down Expand Up @@ -942,6 +953,7 @@ def _load_auto_model(
model_name_or_path: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
trust_remote_code: bool = False,
):
"""
Expand All @@ -955,8 +967,8 @@ def _load_auto_model(
transformer_model = Transformer(
model_name_or_path,
cache_dir=cache_folder,
model_args={"token": token, "trust_remote_code": trust_remote_code},
tokenizer_args={"token": token, "trust_remote_code": trust_remote_code},
model_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
tokenizer_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
return [transformer_model, pooling_model]
Expand All @@ -966,14 +978,19 @@ def _load_sbert_model(
model_name_or_path: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
trust_remote_code: bool = False,
):
"""
Loads a full sentence-transformers model
"""
# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = load_file_path(
model_name_or_path, "config_sentence_transformers.json", token=token, cache_folder=cache_folder
model_name_or_path,
"config_sentence_transformers.json",
token=token,
cache_folder=cache_folder,
revision=revision,
)
if config_sentence_transformers_json_path is not None:
with open(config_sentence_transformers_json_path) as fIn:
Expand All @@ -991,7 +1008,9 @@ def _load_sbert_model(
)

# Check if a readme exists
model_card_path = load_file_path(model_name_or_path, "README.md", token=token, cache_folder=cache_folder)
model_card_path = load_file_path(
model_name_or_path, "README.md", token=token, cache_folder=cache_folder, revision=revision
)
if model_card_path is not None:
try:
with open(model_card_path, encoding="utf8") as fIn:
Expand All @@ -1000,7 +1019,9 @@ def _load_sbert_model(
pass

# Load the modules of sentence transformer
modules_json_path = load_file_path(model_name_or_path, "modules.json", token=token, cache_folder=cache_folder)
modules_json_path = load_file_path(
model_name_or_path, "modules.json", token=token, cache_folder=cache_folder, revision=revision
)
with open(modules_json_path) as fIn:
modules_config = json.load(fIn)

Expand All @@ -1021,24 +1042,29 @@ def _load_sbert_model(
"sentence_xlnet_config.json",
]:
config_path = load_file_path(
model_name_or_path, config_name, token=token, cache_folder=cache_folder
model_name_or_path, config_name, token=token, cache_folder=cache_folder, revision=revision
)
if config_path is not None:
with open(config_path) as fIn:
kwargs = json.load(fIn)
break
hub_kwargs = {"token": token, "trust_remote_code": trust_remote_code, "revision": revision}
if "model_args" in kwargs:
kwargs["model_args"].update({"token": token, "trust_remote_code": trust_remote_code})
kwargs["model_args"].update(hub_kwargs)
else:
kwargs["model_args"] = {"token": token, "trust_remote_code": trust_remote_code}
kwargs["model_args"] = hub_kwargs
if "tokenizer_args" in kwargs:
kwargs["tokenizer_args"].update({"token": token, "trust_remote_code": trust_remote_code})
kwargs["tokenizer_args"].update(hub_kwargs)
else:
kwargs["tokenizer_args"] = {"token": token, "trust_remote_code": trust_remote_code}
kwargs["tokenizer_args"] = hub_kwargs
module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
else:
module_path = load_dir_path(
model_name_or_path, module_config["path"], token=token, cache_folder=cache_folder
model_name_or_path,
module_config["path"],
token=token,
cache_folder=cache_folder,
revision=revision,
)
module = module_class.load(module_path)
modules[module_config["name"]] = module
Expand Down
21 changes: 17 additions & 4 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,13 +473,20 @@ def __delattr__(self, attr: str) -> None:


def is_sentence_transformer_model(
model_name_or_path: str, token: Optional[Union[bool, str]] = None, cache_folder: Optional[str] = None
model_name_or_path: str,
token: Optional[Union[bool, str]] = None,
cache_folder: Optional[str] = None,
revision: Optional[str] = None,
) -> bool:
return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder))
return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder, revision=revision))


def load_file_path(
model_name_or_path: str, filename: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]
model_name_or_path: str,
filename: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
) -> Optional[str]:
# If file is local
file_path = os.path.join(model_name_or_path, filename)
Expand All @@ -491,6 +498,7 @@ def load_file_path(
return hf_hub_download(
model_name_or_path,
filename=filename,
revision=revision,
library_name="sentence-transformers",
token=token,
cache_dir=cache_folder,
Expand All @@ -500,7 +508,11 @@ def load_file_path(


def load_dir_path(
model_name_or_path: str, directory: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]
model_name_or_path: str,
directory: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
) -> Optional[str]:
# If file is local
dir_path = os.path.join(model_name_or_path, directory)
Expand All @@ -509,6 +521,7 @@ def load_dir_path(

download_kwargs = {
"repo_id": model_name_or_path,
"revision": revision,
"allow_patterns": f"{directory}/**",
"library_name": "sentence-transformers",
"token": token,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,18 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs):
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)


def test_load_with_revision() -> None:
main_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="main")
latest_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="f3cb857cba53019a20df283396bcca179cf051a4"
)
older_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="ba33022fdf0b0fc2643263f0726f44d0a07d0e24"
)

test_sentence = ["Hello there!"]
main_embeddings = main_model.encode(test_sentence, convert_to_tensor=True)
assert torch.equal(main_embeddings, latest_model.encode(test_sentence, convert_to_tensor=True))
assert not torch.equal(main_embeddings, older_model.encode(test_sentence, convert_to_tensor=True))