feat: add search encoder backend #3492
**Member (Author):** It will be shown in

**Contributor:** Yeah, but people will not know what has happened since 2.0.0. I would probably change

**Member (Author):** I think this is more about changelog #3401

**Contributor:** Fair, we still need the API docs though.
@@ -0,0 +1,23 @@

# Retrieval Search backend

!!! note "Available since 2.3.0"
    This feature was introduced in version **2.3.0**.

For some large datasets, search can take a lot of time and memory. To reduce this, you can use `FaissSearchIndex`. To use it, install the extra with `pip install mteb[faiss]`.

Usage example:

```python
import mteb
from mteb.models import SearchEncoderWrapper
from mteb.models.search_encoder_index import FaissSearchIndex

model = mteb.get_model(...)
index_backend = FaissSearchIndex(model)
model = SearchEncoderWrapper(
    model,
    index_backend=index_backend,
)
...
```

For example, running `minishlab/potion-base-2M` on `SWEbenchVerifiedRR` took 694 seconds instead of 769.
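The backend gets this speedup from FAISS's flat inner-product index; for cosine similarity it L2-normalizes embeddings first, since the cosine similarity of normalized vectors equals their dot product. A minimal numpy sketch of that equivalence (illustrative only, not part of the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(4, 8)).astype(np.float32)
query = rng.normal(size=(1, 8)).astype(np.float32)

# Cosine similarity computed directly
cos = (query @ docs.T) / (
    np.linalg.norm(query, axis=1, keepdims=True) * np.linalg.norm(docs, axis=1)
)

# Same values via plain inner product after L2-normalizing both sides,
# which is what an inner-product index over normalized vectors computes
qn = query / np.linalg.norm(query, axis=1, keepdims=True)
dn = docs / np.linalg.norm(docs, axis=1, keepdims=True)
ip = qn @ dn.T

assert np.allclose(cos, ip, atol=1e-6)
```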
```diff
@@ -1,3 +1,4 @@
+from .cache_backend_protocol import CacheBackendProtocol
 from .cache_wrapper import CachedEmbeddingWrapper

-__all__ = ["CachedEmbeddingWrapper"]
+__all__ = ["CacheBackendProtocol", "CachedEmbeddingWrapper"]
```
@@ -0,0 +1,7 @@

```python
from .search_backend_protocol import IndexEncoderSearchProtocol
from .search_indexes import FaissSearchIndex

__all__ = [
    "FaissSearchIndex",
    "IndexEncoderSearchProtocol",
]
```
@@ -0,0 +1,50 @@

```python
from collections.abc import Callable
from typing import Protocol

from mteb.types import Array, TopRankedDocumentsType


class IndexEncoderSearchProtocol(Protocol):
    """Protocol for search backends used in encoder-based retrieval."""

    def add_documents(
        self,
        embeddings: Array,
        idxs: list[str],
    ) -> None:
        """Add documents to the search backend.

        Args:
            embeddings: Embeddings of the documents to add.
            idxs: IDs of the documents to add.
        """

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search through added corpus embeddings or rerank top-ranked documents.

        Supports both full-corpus and reranking search modes:
        - Full-corpus mode: `top_ranked=None`, uses added corpus embeddings.
        - Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}.

        Args:
            embeddings: Query embeddings, shape (num_queries, dim).
            top_k: Number of top results to return.
            similarity_fn: Function to compute similarity between query and corpus.
            top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking.
            query_idx_to_id: Mapping of query index -> query_id. Used for reranking.

        Returns:
            A tuple (top_k_values, top_k_indices), for each query:
            - top_k_values: List of top-k similarity scores.
            - top_k_indices: List of indices of the top-k documents in the added corpus.
        """

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
```
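For a sense of what satisfying this protocol looks like, here is a toy in-memory backend covering the full-corpus mode with brute-force numpy scoring (a sketch for illustration; the class name and internals are not part of the PR):

```python
import numpy as np


class InMemorySearchIndex:
    """Toy backend: keeps all embeddings in one array, scores by brute force."""

    def __init__(self) -> None:
        self.idxs: list[str] = []
        self.embeddings = None  # np.ndarray once documents are added

    def add_documents(self, embeddings, idxs) -> None:
        # Append IDs and stack embeddings into a single (n_docs, dim) matrix
        self.idxs.extend(idxs)
        emb = np.asarray(embeddings, dtype=np.float32)
        self.embeddings = (
            emb if self.embeddings is None else np.vstack([self.embeddings, emb])
        )

    def search(
        self,
        embeddings,
        top_k,
        similarity_fn,
        top_ranked=None,
        query_idx_to_id=None,
    ):
        # Full-corpus mode only: score every query against every document
        queries = np.asarray(embeddings, dtype=np.float32)
        scores = np.asarray(similarity_fn(queries, self.embeddings))
        order = np.argsort(-scores, axis=1)[:, :top_k]  # best-first indices
        values = np.take_along_axis(scores, order, axis=1)
        return values.tolist(), order.tolist()

    def clear(self) -> None:
        self.idxs = []
        self.embeddings = None


index = InMemorySearchIndex()
index.add_documents(np.eye(3, dtype=np.float32), ["a", "b", "c"])
vals, ids = index.search(
    np.array([[0.0, 1.0, 0.0]], dtype=np.float32),
    top_k=2,
    similarity_fn=lambda q, d: q @ d.T,  # dot-product similarity
)
```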
@@ -0,0 +1,5 @@

```python
from .faiss_search_index import FaissSearchIndex

__all__ = [
    "FaissSearchIndex",
]
```
@@ -0,0 +1,157 @@

```python
import logging
from collections.abc import Callable

import numpy as np
import torch

from mteb._requires_package import requires_package
from mteb.models.model_meta import ScoringFunction
from mteb.models.models_protocols import EncoderProtocol
from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class FaissSearchIndex:
    """FAISS-based backend for encoder-based search.

    Supports both full-corpus retrieval and reranking (via `top_ranked`).

    Notes:
        - Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
        - Expects embeddings to be normalized if cosine similarity is desired.
    """

    _normalize: bool = False

    def __init__(self, model: EncoderProtocol) -> None:
        requires_package(
            self,
            "faiss",
            "FAISS-based search",
            install_instruction="pip install mteb[faiss-cpu]",
        )

        import faiss
        from faiss import IndexFlatIP, IndexFlatL2

        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
        if model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT:
            self.index_type = IndexFlatIP
        elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.COSINE:
            self.index_type = IndexFlatIP
            self._normalize = True
        elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.EUCLIDEAN:
            self.index_type = IndexFlatL2
        else:
            raise ValueError(
                f"FAISS backend does not support similarity function {model.mteb_model_meta.similarity_fn_name}. "
                f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}, {ScoringFunction.EUCLIDEAN}."
            )

        self.idxs: list[str] = []
        self.index: faiss.Index | None = None

    def add_documents(self, embeddings: Array, idxs: list[str]) -> None:
        """Add all document embeddings and their IDs to the FAISS index."""
        import faiss

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        embeddings = embeddings.astype(np.float32)
        self.idxs.extend(idxs)

        if self._normalize:
            faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        if self.index is None:
            self.index = self.index_type(dim)

        self.index.add(embeddings)
        logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search using FAISS."""
        import faiss

        if self.index is None:
            raise ValueError("No index built. Call add_documents() first.")

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        if self._normalize:
            faiss.normalize_L2(embeddings)

        if top_ranked is not None:
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id must be provided when reranking.")

            similarities, ids = self._reranking(
                embeddings,
                top_k,
                top_ranked=top_ranked,
                query_idx_to_id=query_idx_to_id,
            )
        else:
            similarities, ids = self.index.search(embeddings.astype(np.float32), top_k)
            similarities = similarities.tolist()
            ids = ids.tolist()

        if issubclass(self.index_type, faiss.IndexFlatL2):
            similarities = -np.sqrt(np.maximum(similarities, 0))

        return similarities, ids

    def _reranking(
        self,
        embeddings: Array,
        top_k: int,
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
        scores_all: list[list[float]] = []
        idxs_all: list[list[int]] = []

        for query_idx, query_emb in enumerate(embeddings):
            query_id = query_idx_to_id[query_idx]
            ranked_ids = top_ranked.get(query_id)
            if not ranked_ids:
                logger.warning(f"No top-ranked documents for query {query_id}")
                scores_all.append([])
                idxs_all.append([])
                continue

            candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
            d = self.index.d
            candidate_embs = np.vstack(
                [self.index.reconstruct(idx) for idx in candidate_indices]
            )
            sub_reranking_index = self.index_type(d)
            sub_reranking_index.add(candidate_embs)

            # Search returns scores and indices in one call
            scores, local_indices = sub_reranking_index.search(
                query_emb.reshape(1, -1).astype(np.float32),
                min(top_k, len(candidate_indices)),
            )
            # faiss outputs 2D arrays even for a single query
            scores_all.append(scores[0].tolist())
            idxs_all.append(local_indices[0].tolist())

        return scores_all, idxs_all

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.index = None
        self.idxs = []
```
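Note the Euclidean branch above: FAISS's `IndexFlatL2` returns squared L2 distances, so the backend converts them into higher-is-better scores by negating the square root. A quick numpy check that this conversion preserves the ranking (illustrative only, using made-up distances):

```python
import numpy as np

# Squared L2 distances for one query, as IndexFlatL2 would return them
sq_dists = np.array([[4.0, 1.0, 9.0]])

# Convert to similarity scores: closer documents get higher (less negative) scores
scores = -np.sqrt(np.maximum(sq_dists, 0))

# Ranking by score descending matches ranking by distance ascending
assert list(np.argsort(-scores[0])) == list(np.argsort(sq_dists[0]))
```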

> oO does this work?

> Yes, everything after `:` will be triggered before running a function.
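This exchange refers to the class-level annotated default (`_normalize: bool = False` in `FaissSearchIndex`): an annotated assignment in a class body is evaluated once, when the class is defined, and every instance sees that value until `__init__` shadows it with an instance attribute. A small self-contained illustration (the `Backend` class here is hypothetical):

```python
class Backend:
    _normalize: bool = False  # class attribute, evaluated when the class body runs

    def __init__(self, cosine: bool) -> None:
        if cosine:
            self._normalize = True  # instance attribute shadows the class default


# Instances that never assign the attribute fall back to the class default
assert Backend(cosine=False)._normalize is False
assert Backend(cosine=True)._normalize is True
```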