From da7770613a87a123acec3d9300d66a899d31558d Mon Sep 17 00:00:00 2001
From: zhichao-aws
Date: Tue, 15 Jul 2025 02:32:11 +0000
Subject: [PATCH] add opensearch inference-free models

---
 .../models/opensearch_neural_sparse_models.py | 253 ++++++++++++++++++
 mteb/models/overview.py                       |   2 +
 2 files changed, 255 insertions(+)
 create mode 100644 mteb/models/opensearch_neural_sparse_models.py

diff --git a/mteb/models/opensearch_neural_sparse_models.py b/mteb/models/opensearch_neural_sparse_models.py
new file mode 100644
index 0000000000..6e5552dfff
--- /dev/null
+++ b/mteb/models/opensearch_neural_sparse_models.py
@@ -0,0 +1,253 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import partial
+
+import sentence_transformers
+import torch
+from packaging.version import Version
+from sentence_transformers.sparse_encoder import SparseEncoder
+
+from mteb.encoder_interface import PromptType
+from mteb.model_meta import ModelMeta
+from mteb.models.wrapper import Wrapper
+
+
+v2_training_data = {
+    "MSMARCO": ["train"],
+    # not in MTEB. see https://huggingface.co/datasets/sentence-transformers/embedding-training-data
+    # "eli5_question_answer": ["train"],
+    # "gooaq_pairs": ["train"],
+    # "searchQA_top5_snippets": ["train"],
+    # "squad_pairs": ["train"],
+    # "stackexchange_duplicate_questions_body_body": ["train"],
+    # "stackexchange_duplicate_questions_title_title": ["train"],
+    # "stackexchange_duplicate_questions_title-body_title-body": ["train"],
+    # "WikiAnswers": ["train"],
+    # "wikihow": ["train"],
+    # "yahoo_answers_question_answer": ["train"],
+    # "yahoo_answers_title_answer": ["train"],
+    # "yahoo_answers_title_question": ["train"],
+}
+
+
+v3_training_data = v2_training_data | {
+    "HotpotQA": ["train"],
+    "FEVER": ["train"],
+    "FIQA": ["train"],
+    "NFCORPUS": ["train"],
+    "SCIFACT": ["train"],
+    # not in MTEB. see https://huggingface.co/datasets/sentence-transformers/embedding-training-data
+    # "NQ-train_pairs": ["train"],
+    # "quora_duplicates": ["train"],
+}
+
+
+class SparseEncoderWrapper(Wrapper):
+    def __init__(
+        self,
+        model_name: str,
+        torch_dtype: torch.dtype = torch.float16,
+        **kwargs,
+    ):
+        if Version(sentence_transformers.__version__) < Version("5.0.0"):
+            raise ImportError(
+                "sentence-transformers version must be >= 5.0.0 to load sparse encoders"
+            )
+
+        self.model_name = model_name
+        # Consume batch_size here; the remaining kwargs are forwarded to SparseEncoder.
+        self.batch_size = kwargs.pop("batch_size", 1000)
+        self.kwargs = kwargs
+        self.model = SparseEncoder(model_name, **kwargs)
+        self.model.to(torch_dtype)
+
+    def similarity(
+        self, query_embeddings: torch.Tensor, corpus_embeddings: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute similarity between sparse query_embeddings and corpus_embeddings in batches.
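+
+        Processing queries in batches bounds the peak memory of each
+        sparse-dense multiplication; per-batch results are concatenated
+        into the full (num_queries, num_corpus) matrix at the end.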
+
+        Args:
+            query_embeddings (Tensor): sparse COO tensor of shape (num_queries, dim)
+            corpus_embeddings (Tensor): tensor of shape (num_corpus, dim)
+
+        Returns:
+            Tensor: similarity matrix of shape (num_queries, num_corpus)
+        """
+        sims = []
+        num_queries = query_embeddings.size(0)
+        batch_size = self.batch_size
+
+        # Ensure query_embeddings is coalesced sparse COO
+        q = query_embeddings.coalesce()
+        indices = q.indices()  # 2 x nnz: [row, col]
+        values = q.values()  # nnz
+        n_cols = q.size(1)
+
+        # Iterate over sparse query embeddings in batches
+        for start in range(0, num_queries, batch_size):
+            end = min(start + batch_size, num_queries)
+            # Select non-zero entries for this batch
+            mask = (indices[0] >= start) & (indices[0] < end)
+            sel_idx = indices[:, mask].clone()
+            sel_idx[0] -= start  # shift row indices to batch-local
+            sel_vals = values[mask].clone()
+
+            # Build sparse batch tensor of shape (batch_rows, dim)
+            batch_q = torch.sparse_coo_tensor(
+                sel_idx,
+                sel_vals,
+                size=(end - start, n_cols),
+                device=q.device,
+                dtype=q.dtype,
+            ).coalesce()
+
+            # Compute similarity for this sparse batch
+            sim_batch = self.model.similarity(batch_q, corpus_embeddings)
+            sims.append(sim_batch)
+
+        # Concatenate all batch results
+        return torch.cat(sims, dim=0)
+
+    def encode(
+        self, sentences: Sequence[str], prompt_type: PromptType | None = None, **kwargs
+    ):
+        if prompt_type == PromptType.query:
+            return self.model.encode_query(
+                sentences,
+                **kwargs,
+            )
+        return self.model.encode_document(sentences, **kwargs)
+
+
+opensearch_neural_sparse_encoding_doc_v3_gte = ModelMeta(
+    name="opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
+    languages=["eng-Latn"],
+    open_weights=True,
+    revision="main",
+    release_date="2025-06-18",
+    n_parameters=137_394_234,
+    memory_usage_mb=549,
+    embed_dim=30522,
+    license="apache-2.0",
+    max_tokens=8192,
+    reference="https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
+    similarity_fn_name="dot",
+    framework=["Sentence Transformers", "PyTorch"],
+    public_training_code="https://github.com/zhichao-aws/opensearch-sparse-model-tuning-sample",
+    public_training_data=True,
+    use_instructions=True,
+    training_datasets=v3_training_data,
+    loader=partial(
+        SparseEncoderWrapper,
+        model_name="opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
+        trust_remote_code=True,
+    ),
+)
+
+
+opensearch_neural_sparse_encoding_doc_v3_distill = ModelMeta(
+    name="opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill",
+    languages=["eng-Latn"],
+    open_weights=True,
+    revision="main",
+    release_date="2025-03-28",
+    n_parameters=66_985_530,
+    memory_usage_mb=267,
+    embed_dim=30522,
+    license="apache-2.0",
+    max_tokens=512,
+    reference="https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill",
+    similarity_fn_name="dot",
+    framework=["Sentence Transformers", "PyTorch"],
+    public_training_code="https://github.com/zhichao-aws/opensearch-sparse-model-tuning-sample",
+    public_training_data=True,
+    use_instructions=True,
+    training_datasets=v3_training_data,
+    loader=partial(
+        SparseEncoderWrapper,
+        model_name="opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill",
+    ),
+)
+
+
+opensearch_neural_sparse_encoding_doc_v2_distill = ModelMeta(
+    name="opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
+    languages=["eng-Latn"],
+    open_weights=True,
+    revision="main",
+    release_date="2024-07-17",
+    n_parameters=66_985_530,
+    memory_usage_mb=267,
+    embed_dim=30522,
+    license="apache-2.0",
+    max_tokens=512,
+    reference="https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
+    similarity_fn_name="dot",
+    framework=["Sentence Transformers", "PyTorch"],
+    public_training_code="https://github.com/zhichao-aws/opensearch-sparse-model-tuning-sample",
+    public_training_data=True,
+    use_instructions=True,
+    training_datasets=v2_training_data,
+    loader=partial(
+        SparseEncoderWrapper,
+        model_name="opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
+    ),
+)
+
+
+opensearch_neural_sparse_encoding_doc_v2_mini = ModelMeta(
+    name="opensearch-project/opensearch-neural-sparse-encoding-doc-v2-mini",
+    languages=["eng-Latn"],
+    open_weights=True,
+    revision="main",
+    release_date="2024-07-18",
+    n_parameters=22_744_506,
+    memory_usage_mb=86,
+    embed_dim=30522,
+    license="apache-2.0",
+    max_tokens=512,
+    reference="https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-doc-v2-mini",
+    similarity_fn_name="dot",
+    framework=["Sentence Transformers", "PyTorch"],
+    public_training_code="https://github.com/zhichao-aws/opensearch-sparse-model-tuning-sample",
+    public_training_data=True,
+    use_instructions=True,
+    training_datasets=v2_training_data,
+    loader=partial(
+        SparseEncoderWrapper,
+        model_name="opensearch-project/opensearch-neural-sparse-encoding-doc-v2-mini",
+    ),
+)
+
+
+opensearch_neural_sparse_encoding_doc_v1 = ModelMeta(
+    name="opensearch-project/opensearch-neural-sparse-encoding-doc-v1",
+    languages=["eng-Latn"],
+    open_weights=True,
+    revision="main",
+    release_date="2024-03-07",
+    n_parameters=132_955_194,
+    memory_usage_mb=507,
+    embed_dim=30522,
+    license="apache-2.0",
+    max_tokens=512,
+    reference="https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-doc-v1",
+    similarity_fn_name="dot",
+    framework=["Sentence Transformers", "PyTorch"],
+    public_training_code="https://github.com/zhichao-aws/opensearch-sparse-model-tuning-sample",
+    public_training_data=True,
+    use_instructions=True,
+    training_datasets={
+        "MSMARCO": ["train"],
+    },
+    loader=partial(
+        SparseEncoderWrapper,
+        model_name="opensearch-project/opensearch-neural-sparse-encoding-doc-v1",
+    ),
+)
diff --git a/mteb/models/overview.py b/mteb/models/overview.py
index c3d842fb80..75d2b0bab7 100644
--- a/mteb/models/overview.py
+++ b/mteb/models/overview.py
@@ -68,6 +68,7 @@
     nvidia_models,
     openai_models,
     openclip_models,
+    opensearch_neural_sparse_models,
     ops_moa_models,
     piccolo_models,
     promptriever_models,
@@ -150,6 +151,7 @@
     nvidia_llama_nemoretriever_colemb,
     openai_models,
     openclip_models,
+    opensearch_neural_sparse_models,
     ops_moa_models,
     piccolo_models,
     gme_v_models,
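-- 
Usage sketch (assumes sentence-transformers >= 5.0.0 and access to the
Hugging Face Hub; the model choice and inputs are illustrative only):

    import torch

    from mteb.encoder_interface import PromptType
    from mteb.models.opensearch_neural_sparse_models import SparseEncoderWrapper

    # float32 keeps sparse ops CPU-friendly; the wrapper defaults to float16.
    model = SparseEncoderWrapper(
        "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-mini",
        torch_dtype=torch.float32,
    )

    queries = ["what is neural sparse retrieval"]
    docs = ["Neural sparse retrieval encodes text into sparse lexical vectors."]

    q_emb = model.encode(queries, prompt_type=PromptType.query)  # query path
    d_emb = model.encode(docs)  # document path (default)
    scores = model.similarity(q_emb, d_emb)  # (num_queries, num_docs) dot scores
    print(scores)

The "inference-free" part refers to the query side: for these doc-only models,
encode_query does not run the transformer, so query encoding stays lightweight.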