Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions pylate/models/Dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
Expand Down Expand Up @@ -110,26 +111,49 @@ def from_stanford_weights(
"""
# Check if the model is locally available
if not (os.path.exists(os.path.join(model_name_or_path))):
# Else download the model/use the cached version
model_name_or_path = cached_file(
model_name_or_path,
filename="pytorch_model.bin",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
# If the model a local folder, load the PyTorch model
# Else download the model/use the cached version. We first try to use the safetensors version and fall back to bin if not existing. All the recent stanford-nlp models are safetensors but we keep bin for compatibility.
try:
model_name_or_path = cached_file(
model_name_or_path,
filename="model.safetensors",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
except EnvironmentError:
print("No safetensor model found, falling back to bin.")
model_name_or_path = cached_file(
model_name_or_path,
filename="pytorch_model.bin",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
# If the model a local folder, load the safetensor
# Again, we first try to load the safetensors version and fall back to bin if not existing.
else:
if os.path.exists(os.path.join(model_name_or_path, "model.safetensors")):
model_name_or_path = os.path.join(
model_name_or_path, "model.safetensors"
)
else:
print("No safetensor model found, falling back to bin.")
model_name_or_path = os.path.join(
model_name_or_path, "pytorch_model.bin"
)
if model_name_or_path.endswith("safetensors"):
with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
state_dict = {"linear.weight": f.get_tensor("linear.weight")}
else:
model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin")

# Load the state dict using torch.load instead of safe_open
state_dict = {
"linear.weight": torch.load(model_name_or_path, map_location="cpu")[
"linear.weight"
]
}
state_dict = {
"linear.weight": torch.load(model_name_or_path, map_location="cpu")[
"linear.weight"
]
}

# Determine input and output dimensions
in_features = state_dict["linear.weight"].shape[1]
Expand Down
90 changes: 90 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import math

import torch

from pylate import models, rank


def test_model_creation(**kwargs) -> None:
    """Check that ColBERT models can be created from each kind of checkpoint.

    Five checkpoints are instantiated in turn: a plain base encoder, a
    sentence-transformers model, a stanford-nlp style model shipped as
    safetensors, a stanford-nlp style model shipped as ``pytorch_model.bin``,
    and a PyLate model. For the last three, one query and two documents are
    encoded and reranked, and the two resulting scores are compared against
    reference values with ``math.isclose`` (1% relative / 0.01 absolute
    tolerance).

    Any keyword arguments are accepted and ignored.
    """
    query = ["fruits are healthy."]
    documents = [["fruits are healthy.", "fruits are good for health."]]
    torch.manual_seed(42)

    # Models built on top of a base encoder or a sentence-transformers
    # checkpoint are only constructed here. Their embeddings are not checked
    # for now because newly initialised models still need to be made
    # deterministic. The disabled assertions expected scores of about
    # (25.92, 23.7) for "bert-base-uncased" and (18.77, 18.63) for
    # "sentence-transformers/all-MiniLM-L6-v2".
    for base_checkpoint in (
        "bert-base-uncased",
        "sentence-transformers/all-MiniLM-L6-v2",
    ):
        models.ColBERT(model_name_or_path=base_checkpoint)

    # (checkpoint, expected scores of the two reranked documents, in order)
    scored_checkpoints = (
        # Creation from stanford-nlp (safetensor)
        ("answerdotai/answerai-colbert-small-v1", (31.71, 31.64)),
        # Creation from stanford-nlp (bin)
        ("Crystalcareai/Colbertv2", (31.15, 30.61)),
        # Creation from PyLate
        ("lightonai/colbertv2.0", (30.01, 26.98)),
    )

    for checkpoint, expected_scores in scored_checkpoints:
        model = models.ColBERT(model_name_or_path=checkpoint)
        queries_embeddings = model.encode(sentences=query, is_query=True)
        documents_embeddings = model.encode(sentences=documents, is_query=False)
        reranked_documents = rank.rerank(
            documents_ids=[["1", "2"]],
            queries_embeddings=queries_embeddings,
            documents_embeddings=documents_embeddings,
        )

        # Compare the score at each rank of the single query's result list.
        for position, expected in enumerate(expected_scores):
            assert math.isclose(
                reranked_documents[0][position]["score"],
                expected,
                rel_tol=0.01,
                abs_tol=0.01,
            )