Load Stanford ColBERT weights from pytorch_model.bin; bump pinned dependencies.

Dense.from_stanford_weights now reads "linear.weight" from pytorch_model.bin
with torch.load instead of reading model.safetensors through safe_open.
Because a .bin checkpoint is a pickle and may have been downloaded through
cached_file, torch.load is called with weights_only=True so that loading it
cannot execute arbitrary pickled code.

setup.py pins sentence-transformers to 3.3.0 and adds transformers == 4.46.2.

NOTE(review): this patch was recovered from text whose line breaks were
collapsed. Indentation and blank context lines are reconstructed to match the
hunk headers -- confirm against the target files before applying. The final
setup.py hunk header declares two more trailing context lines than were
visible; they are not reproduced here.

diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py
index b9f35754..2ae1783b 100644
--- a/pylate/models/Dense.py
+++ b/pylate/models/Dense.py
@@ -2,7 +2,6 @@
 import os
 
 import torch
-from safetensors import safe_open
 from safetensors.torch import load_model as load_safetensors_model
 from sentence_transformers.models import Dense as DenseSentenceTransformer
 from sentence_transformers.util import import_from_string
@@ -114,18 +113,23 @@ def from_stanford_weights(
             # Else download the model/use the cached version
             model_name_or_path = cached_file(
                 model_name_or_path,
-                filename="model.safetensors",
+                filename="pytorch_model.bin",
                 cache_dir=cache_folder,
                 revision=revision,
                 local_files_only=local_files_only,
                 token=token,
                 use_auth_token=use_auth_token,
             )
-        # If the model a local folder, load the safetensor
+        # If the model is a local folder, load the PyTorch checkpoint from it
         else:
-            model_name_or_path = os.path.join(model_name_or_path, "model.safetensors")
-        with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
-            state_dict = {"linear.weight": f.get_tensor("linear.weight")}
+            model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin")
+
+        # weights_only=True keeps torch.load from unpickling arbitrary objects, which
+        # matters because the checkpoint may have been downloaded rather than local
+        checkpoint = torch.load(
+            model_name_or_path, map_location="cpu", weights_only=True
+        )
+        state_dict = {"linear.weight": checkpoint["linear.weight"]}
 
         # Determine input and output dimensions
         in_features = state_dict["linear.weight"].shape[1]
diff --git a/setup.py b/setup.py
index 75b2e748..faa934ae 100644
--- a/setup.py
+++ b/setup.py
@@ -6,12 +6,13 @@
     long_description = fh.read()
 
 base_packages = [
-    "sentence-transformers == 3.2.0",
+    "sentence-transformers == 3.3.0",
     "datasets >= 2.20.0",
     "accelerate >= 0.31.0",
     "voyager >= 2.0.9",
     "sqlitedict >= 2.1.0",
     "pandas >= 2.2.1",
+    "transformers == 4.46.2",
 ]