diff --git a/Makefile b/Makefile index ca252b158a..e40c191a73 100644 --- a/Makefile +++ b/Makefile @@ -35,9 +35,8 @@ pr: make lint make test -build-docs: +build-docs: build-docs-overview @echo "--- 📚 Building documentation ---" - make build-docs-overview python -m mkdocs build diff --git a/docs/advanced_usage/retrieval_backend.md b/docs/advanced_usage/retrieval_backend.md new file mode 100644 index 0000000000..fb8f964983 --- /dev/null +++ b/docs/advanced_usage/retrieval_backend.md @@ -0,0 +1,23 @@ +# Retrieval search backend + +!!! note "Available since 2.3.0" + This feature was introduced in version **2.3.0**. + +For some large datasets, search can take a lot of time and memory. To reduce this, you can use `FaissSearchIndex`. To use it, install the optional dependency with `pip install mteb[faiss]`. + +Usage example: +```python +import mteb +from mteb.models import SearchEncoderWrapper +from mteb.models.search_encoder_index import FaissSearchIndex + +model = mteb.get_model(...) +index_backend = FaissSearchIndex(model) +model = SearchEncoderWrapper( + model, + index_backend=index_backend +) +... +``` + +For example, running `minishlab/potion-base-2M` on `SWEbenchVerifiedRR` took 694 seconds instead of 769. diff --git a/docs/api/model.md b/docs/api/model.md index ef57cb2293..038138aca4 100644 --- a/docs/api/model.md +++ b/docs/api/model.md @@ -33,3 +33,7 @@ length, valid frameworks, license, and degree of openness.
:::mteb.models.CrossEncoderProtocol :::mteb.models.MTEBModels + +:::mteb.models.IndexEncoderSearchProtocol + +:::mteb.models.CacheBackendProtocol diff --git a/mkdocs.yml b/mkdocs.yml index 9453c64695..182131f50b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -83,6 +83,7 @@ nav: - Advanced Usage: - Two stage reranking: advanced_usage/two_stage_reranking.md - Cache embeddings: advanced_usage/cache_embeddings.md + - Retrieval backend: advanced_usage/retrieval_backend.md - Contributing: - Adding a Model: contributing/adding_a_model.md - Adding a Task: contributing/adding_a_dataset.md diff --git a/mteb/__init__.py b/mteb/__init__.py index e918cb41c0..b6f4142bc7 100644 --- a/mteb/__init__.py +++ b/mteb/__init__.py @@ -9,8 +9,10 @@ from mteb.get_tasks import get_task, get_tasks from mteb.load_results import load_results from mteb.models import ( + CacheBackendProtocol, CrossEncoderProtocol, EncoderProtocol, + IndexEncoderSearchProtocol, SearchProtocol, SentenceTransformerEncoderWrapper, ) @@ -27,8 +29,10 @@ "AbsTask", "Benchmark", "BenchmarkResults", + "CacheBackendProtocol", "CrossEncoderProtocol", "EncoderProtocol", + "IndexEncoderSearchProtocol", "SearchProtocol", "SentenceTransformerEncoderWrapper", "TaskMetadata", diff --git a/mteb/models/__init__.py b/mteb/models/__init__.py index ed59708029..1c70a191e6 100644 --- a/mteb/models/__init__.py +++ b/mteb/models/__init__.py @@ -1,4 +1,4 @@ -from .cache_wrappers import CachedEmbeddingWrapper +from .cache_wrappers import CacheBackendProtocol, CachedEmbeddingWrapper from .model_meta import ModelMeta from .models_protocols import ( CrossEncoderProtocol, @@ -6,6 +6,7 @@ MTEBModels, SearchProtocol, ) +from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol from .search_wrappers import SearchCrossEncoderWrapper, SearchEncoderWrapper from .sentence_transformer_wrapper import ( CrossEncoderWrapper, @@ -14,10 +15,12 @@ ) __all__ = [ + "CacheBackendProtocol", "CachedEmbeddingWrapper", 
"CrossEncoderProtocol", "CrossEncoderWrapper", "EncoderProtocol", + "IndexEncoderSearchProtocol", "MTEBModels", "ModelMeta", "SearchCrossEncoderWrapper", diff --git a/mteb/models/cache_wrappers/__init__.py b/mteb/models/cache_wrappers/__init__.py index 12b708c75a..efc64515d2 100644 --- a/mteb/models/cache_wrappers/__init__.py +++ b/mteb/models/cache_wrappers/__init__.py @@ -1,3 +1,4 @@ +from .cache_backend_protocol import CacheBackendProtocol from .cache_wrapper import CachedEmbeddingWrapper -__all__ = ["CachedEmbeddingWrapper"] +__all__ = ["CacheBackendProtocol", "CachedEmbeddingWrapper"] diff --git a/mteb/models/model_implementations/random_baseline.py b/mteb/models/model_implementations/random_baseline.py index f8bac508e2..562b54914c 100644 --- a/mteb/models/model_implementations/random_baseline.py +++ b/mteb/models/model_implementations/random_baseline.py @@ -8,6 +8,10 @@ from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta +from mteb.similarity_functions import ( + select_pairwise_similarity, + select_similarity, +) from mteb.types._encoder_io import Array, BatchedInput, PromptType @@ -155,15 +159,9 @@ def similarity( Returns: Cosine similarity matrix between the two sets of embeddings """ - norm1 = np.linalg.norm( - embeddings1.reshape(-1, self.embedding_dim), axis=1, keepdims=True - ) - norm2 = np.linalg.norm( - embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True + return select_similarity( + embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name ) - normalized1 = embeddings1 / (norm1 + 1e-10) - normalized2 = embeddings2 / (norm2 + 1e-10) - return np.dot(normalized1, normalized2.T) def similarity_pairwise( self, @@ -179,17 +177,9 @@ def similarity_pairwise( Returns: Cosine similarity for each pair of embeddings """ - norm1 = np.linalg.norm( - embeddings1.reshape(-1, self.embedding_dim), axis=1, keepdims=True - ) - norm2 = np.linalg.norm( - embeddings2.reshape(-1, self.embedding_dim), 
from collections.abc import Callable
from typing import Protocol

from mteb.types import Array, TopRankedDocumentsType


class IndexEncoderSearchProtocol(Protocol):
    """Protocol for search backends used in encoder-based retrieval."""

    def add_documents(
        self,
        embeddings: Array,
        idxs: list[str],
    ) -> None:
        """Register document embeddings with the backend.

        Args:
            embeddings: Embeddings of the documents to add.
            idxs: IDs of the documents to add, parallel to `embeddings`.
        """

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Retrieve the top-k documents for each query embedding.

        Two modes are supported:
        - Full-corpus mode (`top_ranked is None`): scores every document
          previously passed to `add_documents`.
        - Reranking mode: `top_ranked` maps each query ID to its candidate
          document IDs, and only those candidates are scored. In this mode
          `query_idx_to_id` must map row indices of `embeddings` to query IDs.

        Args:
            embeddings: Query embeddings, shape (num_queries, dim).
            top_k: Number of results to return per query.
            similarity_fn: Similarity function between query and corpus embeddings.
            top_ranked: Optional mapping query_id -> candidate doc IDs (reranking).
            query_idx_to_id: Optional mapping query index -> query ID (reranking).

        Returns:
            A `(top_k_values, top_k_indices)` pair where, for each query,
            `top_k_values` holds the top-k similarity scores and
            `top_k_indices` holds the matching positions in the added corpus.
        """

    def clear(self) -> None:
        """Drop all stored documents and embeddings, resetting the backend."""
import logging
from collections.abc import Callable

import numpy as np
import torch

from mteb._requires_package import requires_package
from mteb.models.model_meta import ScoringFunction
from mteb.models.models_protocols import EncoderProtocol
from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class FaissSearchIndex:
    """FAISS-based backend for encoder-based search.

    Supports both full-corpus retrieval and reranking (via `top_ranked`).

    Notes:
        - Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
        - Embeddings are L2-normalized internally when the model's similarity
          function is cosine; callers do not need to pre-normalize.
    """

    # Whether to L2-normalize embeddings before adding/searching (cosine models).
    _normalize: bool = False

    def __init__(self, model: EncoderProtocol) -> None:
        """Create an index matching the model's similarity function.

        Args:
            model: Encoder whose `mteb_model_meta.similarity_fn_name` selects
                the FAISS index type (inner product for dot/cosine, L2 for
                euclidean).

        Raises:
            ValueError: If the model's similarity function is unsupported.
        """
        requires_package(
            self,
            "faiss",
            "FAISS-based search",
            # NOTE(review): keep in sync with the docs, which advertise
            # `pip install mteb[faiss]` — confirm the extra's actual name.
            install_instruction="pip install mteb[faiss]",
        )

        import faiss
        from faiss import IndexFlatIP, IndexFlatL2

        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
        similarity = model.mteb_model_meta.similarity_fn_name
        if similarity is ScoringFunction.DOT_PRODUCT:
            self.index_type = IndexFlatIP
        elif similarity is ScoringFunction.COSINE:
            # Cosine similarity == inner product on L2-normalized vectors.
            self.index_type = IndexFlatIP
            self._normalize = True
        elif similarity is ScoringFunction.EUCLIDEAN:
            self.index_type = IndexFlatL2
        else:
            # BUG FIX: the "Available" list previously omitted EUCLIDEAN even
            # though the branch above supports it.
            raise ValueError(
                f"FAISS backend does not support similarity function {similarity}. "
                f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}, "
                f"{ScoringFunction.EUCLIDEAN}."
            )

        self.idxs: list[str] = []
        self.index: faiss.Index | None = None

    def add_documents(self, embeddings: Array, idxs: list[str]) -> None:
        """Add document embeddings and their IDs to the FAISS index.

        Args:
            embeddings: Document embeddings, shape (num_docs, dim).
            idxs: Document IDs, parallel to `embeddings`.
        """
        import faiss

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # faiss requires contiguous float32 input (both for normalize_L2 and add).
        embeddings = embeddings.astype(np.float32)
        self.idxs.extend(idxs)

        if self._normalize:
            faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        if self.index is None:
            self.index = self.index_type(dim)

        self.index.add(embeddings)
        logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search the index (full-corpus) or rerank candidates per query.

        Args:
            embeddings: Query embeddings, shape (num_queries, dim).
            top_k: Number of top results per query.
            similarity_fn: Unused here; scoring is delegated to FAISS. Kept
                for protocol compatibility.
            top_ranked: Mapping query_id -> candidate doc IDs (reranking mode).
            query_idx_to_id: Mapping query index -> query ID (reranking mode).

        Returns:
            (top_k_values, top_k_indices) lists per query. For L2 indexes the
            values are negated distances so that higher always means better.

        Raises:
            ValueError: If no index was built, or reranking is requested
                without `query_idx_to_id`.
        """
        import faiss

        if self.index is None:
            # BUG FIX: message previously pointed at a non-existent
            # `add_document()` method.
            raise ValueError("No index built. Call add_documents() first.")

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # BUG FIX: cast before normalize_L2 — faiss requires float32 and the
        # original normalized the raw (possibly float64) query embeddings.
        embeddings = embeddings.astype(np.float32)
        if self._normalize:
            faiss.normalize_L2(embeddings)

        if top_ranked is not None:
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id must be provided when reranking.")

            similarities, ids = self._reranking(
                embeddings,
                top_k,
                top_ranked=top_ranked,
                query_idx_to_id=query_idx_to_id,
            )
        else:
            similarities, ids = self.index.search(embeddings, top_k)
            similarities = similarities.tolist()
            ids = ids.tolist()

        if issubclass(self.index_type, faiss.IndexFlatL2):
            # FAISS L2 indexes return *squared* distances; convert to negated
            # distances so larger values mean "more similar".
            # BUG FIX: convert row by row — reranking can yield ragged or empty
            # score lists, which would break a single vectorized np call — and
            # return plain lists as the annotation promises.
            similarities = [
                (-np.sqrt(np.maximum(row, 0))).tolist() for row in similarities
            ]

        return similarities, ids

    def _reranking(
        self,
        embeddings: Array,
        top_k: int,
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Score each query only against its `top_ranked` candidate documents.

        Builds a small throwaway FAISS index per query from the candidates'
        reconstructed vectors, then searches it.

        Returns:
            (scores, local_indices) per query. `local_indices` are positions
            within that query's candidate list, not global corpus positions.
        """
        doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
        scores_all: list[list[float]] = []
        idxs_all: list[list[int]] = []

        for query_idx, query_emb in enumerate(embeddings):
            query_id = query_idx_to_id[query_idx]
            ranked_ids = top_ranked.get(query_id)
            if not ranked_ids:
                logger.warning(f"No top-ranked documents for query {query_id}")
                scores_all.append([])
                idxs_all.append([])
                continue

            candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
            dim = self.index.d
            candidate_embs = np.vstack(
                [self.index.reconstruct(idx) for idx in candidate_indices]
            )
            sub_reranking_index = self.index_type(dim)
            sub_reranking_index.add(candidate_embs)

            # Search returns scores and indices in one call.
            scores, local_indices = sub_reranking_index.search(
                query_emb.reshape(1, -1).astype(np.float32),
                min(top_k, len(candidate_indices)),
            )
            # faiss outputs 2-D arrays even for a single query row.
            scores_all.append(scores[0].tolist())
            idxs_all.append(local_indices[0].tolist())

        return scores_all, idxs_all

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.index = None
        self.idxs = []
def clear(self) -> None: + """Clear all stored documents and embeddings from the backend.""" + self.index = None + self.idxs = [] diff --git a/mteb/models/search_wrappers.py b/mteb/models/search_wrappers.py index acec858faa..4627a95bcb 100644 --- a/mteb/models/search_wrappers.py +++ b/mteb/models/search_wrappers.py @@ -21,6 +21,7 @@ ) from .models_protocols import CrossEncoderProtocol, EncoderProtocol +from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol logger = logging.getLogger(__name__) @@ -28,13 +29,19 @@ class SearchEncoderWrapper: """Wrapper for Encoder models to be used in search tasks.""" - corpus_chunk_size = 50_000 task_corpus: CorpusDatasetType | None - def __init__(self, model: EncoderProtocol): + def __init__( + self, + model: EncoderProtocol, + corpus_chunk_size: int = 50_000, + index_backend: IndexEncoderSearchProtocol | None = None, + ) -> None: self.model = model self.task_corpus = None self.mteb_model_meta = model.mteb_model_meta + self.corpus_chunk_size = corpus_chunk_size + self.index_backend = index_backend def index( self, @@ -56,6 +63,22 @@ def index( """ # Always retain corpus for potential reranking or fallback flows self.task_corpus = corpus + if self.index_backend is not None: + all_doc_embeddings = self.model.encode( + create_dataloader( + corpus, + task_metadata, + prompt_type=PromptType.document, + **encode_kwargs, + ), + task_metadata=task_metadata, + hf_split=hf_split, + hf_subset=hf_subset, + prompt_type=PromptType.document, + **encode_kwargs, + ) + + self.index_backend.add_documents(all_doc_embeddings, corpus["id"]) def search( self, @@ -105,27 +128,74 @@ def search( if top_ranked is not None: logger.info("Reranking pre-ranked documents...") - result_heaps = self._rerank_documents( - query_idx_to_id=query_idx_to_id, - query_embeddings=query_embeddings, - top_ranked=top_ranked, - top_k=top_k, - task_metadata=task_metadata, - hf_subset=hf_subset, - hf_split=hf_split, - encode_kwargs=encode_kwargs, - ) 
+ if self.index_backend is None: + result_heaps = self._rerank_documents( + query_idx_to_id=query_idx_to_id, + query_embeddings=query_embeddings, + top_ranked=top_ranked, + top_k=top_k, + task_metadata=task_metadata, + hf_subset=hf_subset, + hf_split=hf_split, + encode_kwargs=encode_kwargs, + ) + else: + cos_scores_top_k_values, cos_scores_top_k_idx = ( + self.index_backend.search( + query_embeddings, + top_k, + similarity_fn=self.model.similarity, + top_ranked=top_ranked, + query_idx_to_id=query_idx_to_id, + ) + ) + result_heaps = {qid: [] for qid in query_idx_to_id.values()} + for query_itr in range(len(query_embeddings)): + result_heaps = self._rerank_sort_results( + result_heaps=result_heaps, + query_id=query_idx_to_id[query_itr], + ranked_ids=top_ranked[query_idx_to_id[query_itr]], + scores_top_k_idx=torch.tensor( + [cos_scores_top_k_idx[query_itr]] + ), + scores_top_k_values=torch.tensor( + [cos_scores_top_k_values[query_itr]] + ), + ) + self.index_backend.clear() else: logger.info("Performing full corpus search...") - result_heaps = self._full_corpus_search( - query_idx_to_id=query_idx_to_id, - query_embeddings=query_embeddings, - task_metadata=task_metadata, - hf_subset=hf_subset, - hf_split=hf_split, - top_k=top_k, - encode_kwargs=encode_kwargs, - ) + if self.index_backend is None: + result_heaps = self._full_corpus_search( + query_idx_to_id=query_idx_to_id, + query_embeddings=query_embeddings, + task_metadata=task_metadata, + hf_subset=hf_subset, + hf_split=hf_split, + top_k=top_k, + encode_kwargs=encode_kwargs, + ) + else: + cos_scores_top_k_values, cos_scores_top_k_idx = ( + self.index_backend.search( + query_embeddings, + top_k, + similarity_fn=self.model.similarity, + top_ranked=None, + query_idx_to_id=None, + ) + ) + result_heaps = {qid: [] for qid in query_idx_to_id.values()} + result_heaps = self._sort_full_corpus_results( + result_heaps=result_heaps, + query_idx_to_id=query_idx_to_id, + query_embeddings=query_embeddings, + 
cos_scores_top_k_idx=cos_scores_top_k_idx, + cos_scores_top_k_values=cos_scores_top_k_values, + sub_corpus_ids=self.task_corpus["id"], + top_k=top_k, + ) + self.index_backend.clear() # Reset the task corpus dataloader to None to free up memory self.task_corpus = None @@ -192,19 +262,45 @@ def _full_corpus_search( cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist() sub_corpus_ids = list(sub_corpus_ids) - for query_itr in range(len(query_embeddings)): - query_id = query_idx_to_id[query_itr] - for sub_corpus_id, score in zip( - cos_scores_top_k_idx[query_itr], - cos_scores_top_k_values[query_itr], - ): - corpus_id = sub_corpus_ids[sub_corpus_id] - if len(result_heaps[query_id]) < top_k: - # push item on the heap - heapq.heappush(result_heaps[query_id], (score, corpus_id)) - else: - # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element - heapq.heappushpop(result_heaps[query_id], (score, corpus_id)) + result_heaps = self._sort_full_corpus_results( + result_heaps=result_heaps, + query_idx_to_id=query_idx_to_id, + query_embeddings=query_embeddings, + cos_scores_top_k_idx=cos_scores_top_k_idx, + cos_scores_top_k_values=cos_scores_top_k_values, + sub_corpus_ids=sub_corpus_ids, + top_k=top_k, + ) + return result_heaps + + def _sort_full_corpus_results( + self, + result_heaps: dict[str, list[tuple[float, str]]], + query_idx_to_id: dict[int, str], + query_embeddings: Array, + cos_scores_top_k_idx: list[list[int]], + cos_scores_top_k_values: list[list[float]], + sub_corpus_ids: list[str], + top_k: int, + ) -> dict[str, list[tuple[float, str]]]: + """Sort the heaps into descending order lists. + + Returns: + A dictionary mapping query IDs to a sorted list of tuples, each containing a relevance score and a document ID. 
+ """ + for query_itr in range(len(query_embeddings)): + query_id = query_idx_to_id[query_itr] + for sub_corpus_id, score in zip( + cos_scores_top_k_idx[query_itr], + cos_scores_top_k_values[query_itr], + ): + corpus_id = sub_corpus_ids[sub_corpus_id] + if len(result_heaps[query_id]) < top_k: + # push item on the heap + heapq.heappush(result_heaps[query_id], (score, corpus_id)) + else: + # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element + heapq.heappushpop(result_heaps[query_id], (score, corpus_id)) return result_heaps def _rerank_documents( @@ -279,14 +375,34 @@ def _rerank_documents( scores_top_k_values = scores_top_k_values.cpu() scores_top_k_idx = scores_top_k_idx.cpu() - # Build result heap - for doc_idx, score in zip( - scores_top_k_idx[0].tolist(), - scores_top_k_values[0].tolist(), - ): - corpus_id = ranked_ids[doc_idx] - heapq.heappush(result_heaps[query_id], (score, corpus_id)) + result_heaps = self._rerank_sort_results( + result_heaps=result_heaps, + query_id=query_id, + ranked_ids=ranked_ids, + scores_top_k_idx=scores_top_k_idx, + scores_top_k_values=scores_top_k_values, + ) + return result_heaps + + def _rerank_sort_results( + self, + result_heaps: list[tuple[float, str]], + query_id: str, + ranked_ids: list[str], + scores_top_k_idx: torch.Tensor, + scores_top_k_values: torch.Tensor, + ) -> list[tuple[float, str]]: + """Sort the heap into descending order list. + Returns: + A sorted list of tuples, each containing a relevance score and a document ID. 
+ """ + for doc_idx, score in zip( + scores_top_k_idx[0].tolist(), + scores_top_k_values[0].tolist(), + ): + corpus_id = ranked_ids[doc_idx] + heapq.heappush(result_heaps[query_id], (score, corpus_id)) return result_heaps def encode( diff --git a/mteb/similarity_functions.py b/mteb/similarity_functions.py index b8e2f5e9ad..1624a034d1 100644 --- a/mteb/similarity_functions.py +++ b/mteb/similarity_functions.py @@ -1,6 +1,7 @@ import torch from mteb.models import EncoderProtocol +from mteb.models.model_meta import ScoringFunction from mteb.types import Array @@ -38,6 +39,54 @@ def compute_pairwise_similarity( return pairwise_cos_sim(embedding1, embedding2) +def select_similarity( + embedding1: Array, + embedding2: Array, + similarity_fn: ScoringFunction, +) -> Array: + """Compute similarity between two sets of embeddings using the specified similarity function. + + Args: + embedding1: The first set of embeddings. + embedding2: The second set of embeddings. + similarity_fn: The similarity function to use (COSINE, DOT_PRODUCT, EUCLIDEAN). + + Returns: + Array: The computed similarity scores. + """ + if similarity_fn is ScoringFunction.COSINE: + return cos_sim(embedding1, embedding2) + elif similarity_fn is ScoringFunction.DOT_PRODUCT: + return dot_score(embedding1, embedding2) + elif similarity_fn is ScoringFunction.EUCLIDEAN: + return euclidean_sim(embedding1, embedding2) + raise ValueError(f"Unsupported similarity function: {similarity_fn}") + + +def select_pairwise_similarity( + embedding1: Array, + embedding2: Array, + similarity_fn: ScoringFunction, +) -> Array: + """Compute pairwise similarity between two sets of embeddings using the specified similarity function. + + Args: + embedding1: The first set of embeddings. + embedding2: The second set of embeddings. + similarity_fn: The similarity function to use (COSINE, DOT_PRODUCT, EUCLIDEAN). + + Returns: + Array: The computed pairwise similarity scores. 
+ """ + if similarity_fn is ScoringFunction.COSINE: + return pairwise_cos_sim(embedding1, embedding2) + elif similarity_fn is ScoringFunction.DOT_PRODUCT: + return pairwise_dot_score(embedding1, embedding2) + elif similarity_fn is ScoringFunction.EUCLIDEAN: + return pairwise_euclidean_sim(embedding1, embedding2) + raise ValueError(f"Unsupported similarity function: {similarity_fn}") + + def _normalize_embeddings(embeddings: Array) -> torch.Tensor: """Normalizes the embeddings matrix, so that each sentence embedding has unit length. diff --git a/tests/test_search_index/__init__.py b/tests/test_search_index/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_search_index/test_search_index.py b/tests/test_search_index/test_search_index.py new file mode 100644 index 0000000000..12d03f2b3c --- /dev/null +++ b/tests/test_search_index/test_search_index.py @@ -0,0 +1,75 @@ +import json +from copy import deepcopy +from pathlib import Path + +import pytest + +import mteb +from mteb.abstasks import AbsTaskRetrieval +from mteb.models import SearchEncoderWrapper +from mteb.models.model_meta import ScoringFunction +from mteb.models.search_encoder_index import FaissSearchIndex +from tests.mock_tasks import ( + MockRerankingTask, + MockRetrievalTask, +) + + +@pytest.mark.parametrize( + "task", + [ + MockRetrievalTask(), + MockRerankingTask(), + ], +) +@pytest.mark.parametrize( + "similarity", + [ScoringFunction.DOT_PRODUCT, ScoringFunction.COSINE, ScoringFunction.EUCLIDEAN], +) +def test_retrieval_backends( + task: AbsTaskRetrieval, similarity: ScoringFunction, tmp_path: Path +): + """Test different retrieval backends for retrieval and reranking tasks.""" + model = mteb.get_model("baseline/random-encoder-baseline") + model_meta = deepcopy(model.mteb_model_meta) + model_meta.similarity_fn_name = similarity + model.mteb_model_meta = model_meta + + faiss_backend = SearchEncoderWrapper(model, index_backend=FaissSearchIndex(model)) + + 
python_backend_predictions = tmp_path / "python_backend_predictions" + faiss_backend_predictions = tmp_path / "faiss_backend_predictions" + + python_results = mteb.evaluate( + model, + task, + prediction_folder=python_backend_predictions, + cache=None, + ) + faiss_results = mteb.evaluate( + faiss_backend, + task, + prediction_folder=faiss_backend_predictions, + cache=None, + ) + + assert ( + python_results.task_results[0].get_score() + == faiss_results.task_results[0].get_score() + ) + + with task._predictions_path(python_backend_predictions).open() as f: + full_python_predictions = json.load(f) + python_predictions = full_python_predictions["default"]["test"] + + with task._predictions_path(faiss_backend_predictions).open() as f: + full_faiss_predictions = json.load(f) + faiss_predictions = full_faiss_predictions["default"]["test"] + + for python_pred_key, faiss_pred_key in zip( + sorted(python_predictions.keys()), sorted(faiss_predictions.keys()) + ): + assert python_pred_key == faiss_pred_key + assert python_predictions[python_pred_key] == pytest.approx( + faiss_predictions[faiss_pred_key] + )