diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 537d4fee9..0b2bcc3f3 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -70,6 +70,8 @@ class SentenceTransformer(nn.Sequential): :param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU can be used. :param cache_folder: Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable. + :param revision: The specific model version to use. It can be a branch name, a tag name, or a commit id, + for a stored model on Hugging Face. :param trust_remote_code: Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. @@ -83,6 +85,7 @@ def __init__( device: Optional[str] = None, cache_folder: Optional[str] = None, trust_remote_code: bool = False, + revision: Optional[str] = None, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, ): @@ -187,13 +190,21 @@ def __init__( # A model from sentence-transformers model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path - if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder): + if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder, revision=revision): modules = self._load_sbert_model( - model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code + model_name_or_path, + token=token, + cache_folder=cache_folder, + revision=revision, + trust_remote_code=trust_remote_code, ) else: modules = self._load_auto_model( - model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code + model_name_or_path, + token=token, + cache_folder=cache_folder, + revision=revision, + trust_remote_code=trust_remote_code, ) if modules is not None and not isinstance(modules, OrderedDict): @@ -942,6 +953,7 @@ def _load_auto_model( model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str], + revision: Optional[str] = None, trust_remote_code: bool = False, ): """ @@ -955,8 +967,8 @@ def _load_auto_model( transformer_model = Transformer( model_name_or_path, cache_dir=cache_folder, - model_args={"token": token, "trust_remote_code": trust_remote_code}, - tokenizer_args={"token": token, "trust_remote_code": trust_remote_code}, + model_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision}, + tokenizer_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision}, ) pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean") return [transformer_model, pooling_model] @@ -966,6 +978,7 @@ def _load_sbert_model( model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str], + revision: Optional[str] = None, trust_remote_code: bool = False, ): """ @@ -973,7 +986,11 @@ def _load_sbert_model( """ # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework) config_sentence_transformers_json_path = load_file_path( - model_name_or_path, "config_sentence_transformers.json", token=token, cache_folder=cache_folder + model_name_or_path, + "config_sentence_transformers.json", + token=token, + cache_folder=cache_folder, + revision=revision, ) if config_sentence_transformers_json_path is not None: with open(config_sentence_transformers_json_path) as fIn: @@ -991,7 +1008,9 @@ def _load_sbert_model( ) # Check if a readme exists - model_card_path = load_file_path(model_name_or_path, "README.md", token=token, cache_folder=cache_folder) + model_card_path = load_file_path( + model_name_or_path, "README.md", token=token, cache_folder=cache_folder, revision=revision + ) if model_card_path is not None: try: with open(model_card_path, encoding="utf8") as fIn: @@ -1000,7 +1019,9 @@ def _load_sbert_model( pass # Load the modules of sentence transformer - modules_json_path = load_file_path(model_name_or_path, "modules.json", token=token, cache_folder=cache_folder) + modules_json_path = load_file_path( + model_name_or_path, "modules.json", token=token, cache_folder=cache_folder, revision=revision + ) with open(modules_json_path) as fIn: modules_config = json.load(fIn) @@ -1021,24 +1042,29 @@ def _load_sbert_model( "sentence_xlnet_config.json", ]: config_path = load_file_path( - model_name_or_path, config_name, token=token, cache_folder=cache_folder + model_name_or_path, config_name, token=token, cache_folder=cache_folder, revision=revision ) if config_path is not None: with open(config_path) as fIn: kwargs = json.load(fIn) break + hub_kwargs = {"token": token, "trust_remote_code": trust_remote_code, "revision": revision} if "model_args" in kwargs: - kwargs["model_args"].update({"token": token, "trust_remote_code": trust_remote_code}) + kwargs["model_args"].update(hub_kwargs) else: - kwargs["model_args"] = {"token": token, "trust_remote_code": trust_remote_code} + kwargs["model_args"] = hub_kwargs if "tokenizer_args" in kwargs: - kwargs["tokenizer_args"].update({"token": token, "trust_remote_code": trust_remote_code}) + kwargs["tokenizer_args"].update(hub_kwargs) else: - kwargs["tokenizer_args"] = {"token": token, "trust_remote_code": trust_remote_code} + kwargs["tokenizer_args"] = hub_kwargs module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs) else: module_path = load_dir_path( - model_name_or_path, module_config["path"], token=token, cache_folder=cache_folder + model_name_or_path, + module_config["path"], + token=token, + cache_folder=cache_folder, + revision=revision, ) module = module_class.load(module_path) modules[module_config["name"]] = module diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 1f18f8bb8..f1d550a8e 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -473,13 +473,20 @@ def __delattr__(self, attr: str) -> None: def is_sentence_transformer_model( - model_name_or_path: str, token: Optional[Union[bool, str]] = None, cache_folder: Optional[str] = None + model_name_or_path: str, + token: Optional[Union[bool, str]] = None, + cache_folder: Optional[str] = None, + revision: Optional[str] = None, ) -> bool: - return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder)) + return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder, revision=revision)) def load_file_path( - model_name_or_path: str, filename: str, token: Optional[Union[bool, str]], cache_folder: Optional[str] + model_name_or_path: str, + filename: str, + token: Optional[Union[bool, str]], + cache_folder: Optional[str], + revision: Optional[str] = None, ) -> Optional[str]: # If file is local file_path = os.path.join(model_name_or_path, filename) @@ -491,6 +498,7 @@ def load_file_path( return hf_hub_download( model_name_or_path, filename=filename, + revision=revision, library_name="sentence-transformers", token=token, cache_dir=cache_folder, @@ -500,7 +508,11 @@ def load_file_path( def load_dir_path( - model_name_or_path: str, directory: str, token: Optional[Union[bool, str]], cache_folder: Optional[str] + model_name_or_path: str, + directory: str, + token: Optional[Union[bool, str]], + cache_folder: Optional[str], + revision: Optional[str] = None, ) -> Optional[str]: # If file is local dir_path = os.path.join(model_name_or_path, directory) @@ -509,6 +521,7 @@ def load_dir_path( download_kwargs = { "repo_id": model_name_or_path, + "revision": revision, "allow_patterns": f"{directory}/**", "library_name": "sentence-transformers", "token": token, diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 87cd1aaea..e60be4154 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -171,3 +171,18 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs): caplog.record_tuples[0][2] == 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.' ) + + +def test_load_with_revision() -> None: + main_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="main") + latest_model = SentenceTransformer( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="f3cb857cba53019a20df283396bcca179cf051a4" + ) + older_model = SentenceTransformer( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="ba33022fdf0b0fc2643263f0726f44d0a07d0e24" + ) + + test_sentence = ["Hello there!"] + main_embeddings = main_model.encode(test_sentence, convert_to_tensor=True) + assert torch.equal(main_embeddings, latest_model.encode(test_sentence, convert_to_tensor=True)) + assert not torch.equal(main_embeddings, older_model.encode(test_sentence, convert_to_tensor=True))