Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,33 +68,28 @@ def __init__(self, model_name_or_path: Optional[str] = None,
#Old models that don't belong to any organization
basic_transformer_models = ['albert-base-v1', 'albert-base-v2', 'albert-large-v1', 'albert-large-v2', 'albert-xlarge-v1', 'albert-xlarge-v2', 'albert-xxlarge-v1', 'albert-xxlarge-v2', 'bert-base-cased-finetuned-mrpc', 'bert-base-cased', 'bert-base-chinese', 'bert-base-german-cased', 'bert-base-german-dbmdz-cased', 'bert-base-german-dbmdz-uncased', 'bert-base-multilingual-cased', 'bert-base-multilingual-uncased', 'bert-base-uncased', 'bert-large-cased-whole-word-masking-finetuned-squad', 'bert-large-cased-whole-word-masking', 'bert-large-cased', 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking', 'bert-large-uncased', 'camembert-base', 'ctrl', 'distilbert-base-cased-distilled-squad', 'distilbert-base-cased', 'distilbert-base-german-cased', 'distilbert-base-multilingual-cased', 'distilbert-base-uncased-distilled-squad', 'distilbert-base-uncased-finetuned-sst-2-english', 'distilbert-base-uncased', 'distilgpt2', 'distilroberta-base', 'gpt2-large', 'gpt2-medium', 'gpt2-xl', 'gpt2', 'openai-gpt', 'roberta-base-openai-detector', 'roberta-base', 'roberta-large-mnli', 'roberta-large-openai-detector', 'roberta-large', 't5-11b', 't5-3b', 't5-base', 't5-large', 't5-small', 'transfo-xl-wt103', 'xlm-clm-ende-1024', 'xlm-clm-enfr-1024', 'xlm-mlm-100-1280', 'xlm-mlm-17-1280', 'xlm-mlm-en-2048', 'xlm-mlm-ende-1024', 'xlm-mlm-enfr-1024', 'xlm-mlm-enro-1024', 'xlm-mlm-tlm-xnli15-1024', 'xlm-mlm-xnli15-1024', 'xlm-roberta-base', 'xlm-roberta-large-finetuned-conll02-dutch', 'xlm-roberta-large-finetuned-conll02-spanish', 'xlm-roberta-large-finetuned-conll03-english', 'xlm-roberta-large-finetuned-conll03-german', 'xlm-roberta-large', 'xlnet-base-cased', 'xlnet-large-cased']

if os.path.exists(model_name_or_path):
#Load from path
model_path = model_name_or_path
else:
#Not a path, load from hub
if not os.path.exists(model_name_or_path):
# Not a path, load from hub
if '\\' in model_name_or_path or model_name_or_path.count('/') > 1:
raise ValueError("Path {} not found".format(model_name_or_path))

if '/' not in model_name_or_path and model_name_or_path.lower() not in basic_transformer_models:
# A model from sentence-transformers
model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path

model_path = os.path.join(cache_folder, model_name_or_path.replace("/", "_"))

if not os.path.exists(os.path.join(model_path, 'modules.json')):
if not os.path.exists(os.path.join(model_name_or_path, 'modules.json')):
# Download from hub with caching
snapshot_download(model_name_or_path,
model_name_or_path = snapshot_download(model_name_or_path,
cache_dir=cache_folder,
library_name='sentence-transformers',
library_version=__version__,
ignore_files=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
use_auth_token=use_auth_token)

if os.path.exists(os.path.join(model_path, 'modules.json')): #Load as SentenceTransformer model
modules = self._load_sbert_model(model_path)
if os.path.exists(os.path.join(model_name_or_path, 'modules.json')): #Load as SentenceTransformer model
modules = self._load_sbert_model(model_name_or_path)
else: #Load with AutoModel
modules = self._load_auto_model(model_path)
modules = self._load_auto_model(model_name_or_path)

if modules is not None and not isinstance(modules, OrderedDict):
modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
Expand Down
40 changes: 14 additions & 26 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import huggingface_hub
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub import HfApi, hf_hub_url, cached_download, HfFolder
from huggingface_hub import HfApi, hf_hub_url, hf_hub_download, HfFolder
import fnmatch
from packaging import version
import heapq
Expand Down Expand Up @@ -449,10 +449,6 @@ def snapshot_download(

model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)

storage_folder = os.path.join(
cache_dir, repo_id.replace("/", "_")
)

all_files = model_info.siblings
#Download modules.json as the last file
for idx, repofile in enumerate(all_files):
Expand All @@ -461,6 +457,8 @@ def snapshot_download(
all_files.append(repofile)
break

model_storage_folder = None

for model_file in all_files:
if ignore_files is not None:
skip_download = False
Expand All @@ -472,33 +470,23 @@ def snapshot_download(
if skip_download:
continue

url = hf_hub_url(
repo_id, filename=model_file.rfilename, revision=model_info.sha
)
relative_filepath = os.path.join(*model_file.rfilename.split("/"))

# Create potential nested dir
nested_dirname = os.path.dirname(
os.path.join(storage_folder, relative_filepath)
)
os.makedirs(nested_dirname, exist_ok=True)

cached_download_args = {'url': url,
'cache_dir': storage_folder,
'force_filename': relative_filepath,
cached_download_args = {
'repo_id': repo_id,
'filename': model_file.rfilename,
'revision': model_info.sha,
'cache_dir': cache_dir,
'library_name': library_name,
'library_version': library_version,
'user_agent': user_agent,
'use_auth_token': use_auth_token}
'token': use_auth_token}

if version.parse(huggingface_hub.__version__) >= version.parse("0.8.1"):
# huggingface_hub v0.8.1 introduces a new cache layout. We still use a manual layout
# and need to pass legacy_cache_layout=True to avoid a warning being printed
cached_download_args['legacy_cache_layout'] = True
path = hf_hub_download(**cached_download_args)

path = cached_download(**cached_download_args)
# model path is the directory where the modules.json file is stored
if model_file.rfilename == "modules.json":
model_storage_folder = os.path.dirname(path)

if os.path.exists(path + ".lock"):
os.remove(path + ".lock")

return storage_folder
return model_storage_folder