diff --git a/mteb/encoder_interface.py b/mteb/encoder_interface.py index abb88f3ac4..6ebee5620a 100644 --- a/mteb/encoder_interface.py +++ b/mteb/encoder_interface.py @@ -14,7 +14,7 @@ class PromptType(str, Enum): query = "query" - passage = "passage" + document = "document" @runtime_checkable diff --git a/mteb/evaluation/evaluators/Image/Any2AnyRetrievalEvaluator.py b/mteb/evaluation/evaluators/Image/Any2AnyRetrievalEvaluator.py index 06fb66c0a9..79388451ca 100644 --- a/mteb/evaluation/evaluators/Image/Any2AnyRetrievalEvaluator.py +++ b/mteb/evaluation/evaluators/Image/Any2AnyRetrievalEvaluator.py @@ -195,7 +195,7 @@ def search( sub_corpus_embeddings = self.model.get_text_embeddings( texts=corpus_texts, task_name=task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) else: @@ -213,7 +213,7 @@ def search( sub_corpus_embeddings = self.model.get_image_embeddings( images=corpus_image_dataloader, task_name=task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) elif corpus_modality == "image,text": @@ -222,7 +222,7 @@ def search( texts=corpus_texts, images=corpus_image_dataloader, task_name=task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) else: diff --git a/mteb/evaluation/evaluators/RerankingEvaluator.py b/mteb/evaluation/evaluators/RerankingEvaluator.py index 3c45126bdf..259a57b782 100644 --- a/mteb/evaluation/evaluators/RerankingEvaluator.py +++ b/mteb/evaluation/evaluators/RerankingEvaluator.py @@ -180,7 +180,7 @@ def _encode_candidates_batched( all_docs, model, task_name=self.task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) @@ -245,7 +245,7 @@ def _encode_candidates_individual( model.encode( docs, task_name=self.task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) ) @@ -293,7 +293,7 @@ def 
_encode_candidates_miracl_batched(self, all_query_embs, model: Encoder): model.encode( all_docs, task_name=self.task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) ) @@ -345,7 +345,7 @@ def _encode_candidates_miracl_individual(self, model: Encoder): model.encode( docs, task_name=self.task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, **self.encode_kwargs, ) ) diff --git a/mteb/evaluation/evaluators/RetrievalEvaluator.py b/mteb/evaluation/evaluators/RetrievalEvaluator.py index 69034b1741..9f06df2880 100644 --- a/mteb/evaluation/evaluators/RetrievalEvaluator.py +++ b/mteb/evaluation/evaluators/RetrievalEvaluator.py @@ -159,7 +159,7 @@ def search( sub_corpus_embeddings = self.model.encode( corpus[corpus_start_idx:corpus_end_idx], # type: ignore task_name=task_name, - prompt_type=PromptType.passage, + prompt_type=PromptType.document, request_qid=request_qid, **self.encode_kwargs, ) @@ -385,7 +385,7 @@ def encode_corpus( corpus: list[dict[str, str]], task_name: str, batch_size: int, - prompt_type: PromptType = PromptType.passage, + prompt_type: PromptType = PromptType.document, request_qid: str | None = None, **kwargs, ): @@ -416,7 +416,7 @@ def encode( prompt_type: PromptType | None = None, **kwargs, ): - if prompt_type and prompt_type == PromptType.passage: + if prompt_type and prompt_type == PromptType.document: return self.encode_corpus( sentences, task_name, prompt_type=prompt_type, **kwargs ) diff --git a/mteb/models/cadet_models.py b/mteb/models/cadet_models.py index c2717128cb..81e0dffdb0 100644 --- a/mteb/models/cadet_models.py +++ b/mteb/models/cadet_models.py @@ -35,7 +35,7 @@ revision="8056d118be37a566f20972a5f35cda815f6bc47e", model_prompts={ "query": "query: ", - "passage": "passage: ", + "document": "passage: ", }, ), name="manveertamber/cadet-embed-base-v1", diff --git a/mteb/models/cohere_models.py b/mteb/models/cohere_models.py index 2a40f47218..606195417a 100644 --- 
a/mteb/models/cohere_models.py +++ b/mteb/models/cohere_models.py @@ -211,7 +211,7 @@ def encode( "MultilabelClassification": "classification", "Clustering": "clustering", PromptType.query.value: "search_query", - PromptType.passage.value: "search_document", + PromptType.document.value: "search_document", } cohere_mult_3 = ModelMeta( diff --git a/mteb/models/e5_models.py b/mteb/models/e5_models.py index 395d1da311..6a62f6e8fd 100644 --- a/mteb/models/e5_models.py +++ b/mteb/models/e5_models.py @@ -110,7 +110,7 @@ model_prompts = { PromptType.query.value: "query: ", - PromptType.passage.value: "passage: ", + PromptType.document.value: "passage: ", } E5_TRAINING_DATA = { diff --git a/mteb/models/gme_v_models.py b/mteb/models/gme_v_models.py index 29e7de8c78..f3a2277ab8 100644 --- a/mteb/models/gme_v_models.py +++ b/mteb/models/gme_v_models.py @@ -182,7 +182,7 @@ def encode_corpus(self, corpus: list[dict[str, str]], **kwargs): else doc["text"].strip() for doc in corpus ] - embeddings = self.encode(sentences, prompt_type=PromptType.passage**kwargs) + embeddings = self.encode(sentences, prompt_type=PromptType.document, **kwargs) return embeddings def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs): @@ -210,7 +210,7 @@ def get_fused_embeddings( instruction=None, **kwargs: Any, ): - if prompt_type == PromptType.passage: + if prompt_type == PromptType.document: instruction = None elif instruction is None: instruction = self.get_instruction(task_name, prompt_type) diff --git a/mteb/models/google_models.py b/mteb/models/google_models.py index 1ad9aa0a61..667bf31cda 100644 --- a/mteb/models/google_models.py +++ b/mteb/models/google_models.py @@ -39,7 +39,7 @@ "Clustering": "CLUSTERING", "STS": "SEMANTIC_SIMILARITY", PromptType.query.value: "RETRIEVAL_QUERY", - PromptType.passage.value: "RETRIEVAL_DOCUMENT", + PromptType.document.value: "RETRIEVAL_DOCUMENT", } GECKO_TRAINING_DATA = { diff --git a/mteb/models/instruct_wrapper.py
b/mteb/models/instruct_wrapper.py index 1cd7ae9db6..c5be2a672d 100644 --- a/mteb/models/instruct_wrapper.py +++ b/mteb/models/instruct_wrapper.py @@ -147,10 +147,13 @@ def encode( ) # to passage prompts won't be applied to passages - if not self.apply_instruction_to_passages and prompt_type == PromptType.passage: + if ( + not self.apply_instruction_to_passages + and prompt_type == PromptType.document + ): instruction = None logger.info( - f"No instruction used, because prompt type = {prompt_type.passage}" + f"No instruction used, because prompt type = {prompt_type}" ) if instruction: diff --git a/mteb/models/jasper_models.py b/mteb/models/jasper_models.py index b3a3f62516..cfdc3849df 100644 --- a/mteb/models/jasper_models.py +++ b/mteb/models/jasper_models.py @@ -44,7 +44,7 @@ def encode( instruction = self.get_task_instruction(task_name, prompt_type) # to passage prompts won't be applied to passages - if prompt_type == PromptType.passage and task.metadata.category == "s2p": + if prompt_type == PromptType.document and task.metadata.category == "s2p": instruction = None embeddings = self.model.encode( diff --git a/mteb/models/jina_models.py b/mteb/models/jina_models.py index bbce9748af..30b7878583 100644 --- a/mteb/models/jina_models.py +++ b/mteb/models/jina_models.py @@ -549,7 +549,7 @@ def get_programming_task_override( trust_remote_code=True, model_prompts={ "Retrieval-query": "retrieval.query", - "Retrieval-passage": "retrieval.passage", + "Retrieval-document": "retrieval.passage", "STS": "text-matching", "DocumentUnderstanding": "retrieval.query", }, @@ -584,7 +584,7 @@ def get_programming_task_override( trust_remote_code=True, model_prompts={ "Retrieval-query": "retrieval.query", - "Retrieval-passage": "retrieval.passage", + "Retrieval-document": "retrieval.passage", "Clustering": "separation", "Classification": "classification", "STS": "text-matching", diff --git a/mteb/models/kalm_models.py b/mteb/models/kalm_models.py index 68f17460c3..7ea133a84c 100644 --- a/mteb/models/kalm_models.py +++ b/mteb/models/kalm_models.py @@ -7,6 +7,7 @@ import numpy as np import torch + from mteb.encoder_interface import PromptType from mteb.model_meta import ModelMeta from mteb.models.instruct_wrapper import InstructSentenceTransformerWrapper @@ -37,10 +38,13 @@ def encode( task = get_task(task_name) # to passage prompts won't be applied to passages - if not self.apply_instruction_to_passages and prompt_type == PromptType.passage: + if ( + not self.apply_instruction_to_passages + and prompt_type == PromptType.document + ): instruction = None logger.info( - f"No instruction used, because prompt type = {prompt_type.passage}" + f"No instruction used, because prompt type = {prompt_type}" ) if task.metadata.type in ["STS", "PairClassification", "Summarization"]: @@ -375,26 +379,26 @@ def encode( "ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles", "ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents", "Cmnli-query": "Retrieve semantically similar text", - "Cmnli-passage": "Retrieve semantically similar text", + "Cmnli-document": "Retrieve semantically similar text", "Ocnli-query": "Retrieve semantically similar text", - "Ocnli-passage": "Retrieve semantically similar text", + "Ocnli-document": "Retrieve semantically similar text", "SprintDuplicateQuestions-query": "Retrieve semantically similar questions", - "SprintDuplicateQuestions-passage": "Retrieve semantically similar questions", + "SprintDuplicateQuestions-document":
"Retrieve semantically similar questions", "TwitterSemEval2015-query": "Retrieve semantically similar text", - "TwitterSemEval2015-passage": "Retrieve semantically similar text", + "TwitterSemEval2015-document": "Retrieve semantically similar text", "TwitterURLCorpus-query": "Retrieve semantically similar text", - "TwitterURLCorpus-passage": "Retrieve semantically similar text", + "TwitterURLCorpus-document": "Retrieve semantically similar text", "CMedQAv1-reranking": "Given a query, retrieve documents that answer the query", "CMedQAv2-reranking": "Given a query, retrieve documents that answer the query", "MMarcoReranking": "Given a query, retrieve documents that answer the query", "T2Reranking": "Given a query, retrieve documents that answer the query", "AskUbuntuDupQuestions-query": "Retrieve semantically similar questions", - "AskUbuntuDupQuestions-passage": "Retrieve semantically similar questions", + "AskUbuntuDupQuestions-document": "Retrieve semantically similar questions", "MindSmallReranking": "Given a query, retrieve documents that answer the query", "SciDocsRR-query": "Retrieve relevant paper titles", - "SciDocsRR-passage": "Retrieve relevant paper titles", + "SciDocsRR-document": "Retrieve relevant paper titles", "StackOverflowDupQuestions-query": "Retrieve semantically similar questions", - "StackOverflowDupQuestions-passage": "Retrieve semantically similar questions", + "StackOverflowDupQuestions-document": "Retrieve semantically similar questions", "CmedqaRetrieval": "Given a query, retrieve documents that answer the query", "CovidRetrieval": "Given a query, retrieve documents that answer the query", "DuRetrieval": "Given a query, retrieve documents that answer the query", @@ -407,29 +411,29 @@ def encode( "ArguAna": "Given a query, retrieve documents that answer the query", "ClimateFEVER": "Given a query, retrieve documents that answer the query", "CQADupstackAndroidRetrieval-query": "Given a question, retrieve detailed question descriptions from 
Stackexchange that are duplicates to the given question", - "CQADupstackAndroidRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackAndroidRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackEnglishRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackEnglishRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackEnglishRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackGamingRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackGamingRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackGamingRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackGisRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackGisRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackGisRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackMathematicaRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - 
"CQADupstackMathematicaRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackMathematicaRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackPhysicsRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackPhysicsRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackPhysicsRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackProgrammersRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackProgrammersRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackProgrammersRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackStatsRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackStatsRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackStatsRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackTexRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackTexRetrieval-passage": "Given a 
question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackTexRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackUnixRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackUnixRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackUnixRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackWebmastersRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackWebmastersRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackWebmastersRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackWordpressRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackWordpressRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackWordpressRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "DBPedia": "Given a query, retrieve documents that answer the query", "FEVER": "Given a query, retrieve documents that answer the query", "FiQA2018": "Given a query, retrieve documents that answer the query", @@ -437,48 +441,48 @@ def encode( "NFCorpus": "Given 
a query, retrieve documents that answer the query", "NQ": "Given a query, retrieve documents that answer the query", "QuoraRetrieval-query": "Retrieve semantically similar questions", - "QuoraRetrieval-passage": "Retrieve semantically similar questions", + "QuoraRetrieval-document": "Retrieve semantically similar questions", "SCIDOCS-query": "Given a query, retrieve documents that answer the query", - "SCIDOCS-passage": "Given a query, retrieve documents that answer the query", + "SCIDOCS-document": "Given a query, retrieve documents that answer the query", "SciFact": "Given a query, retrieve documents that answer the query", "Touche2020": "Given a query, retrieve documents that answer the query", "TRECCOVID": "Given a query, retrieve documents that answer the query", "AFQMC-query": "Retrieve semantically similar text", - "AFQMC-passage": "Retrieve semantically similar text", + "AFQMC-document": "Retrieve semantically similar text", "ATEC-query": "Retrieve semantically similar text", - "ATEC-passage": "Retrieve semantically similar text", + "ATEC-document": "Retrieve semantically similar text", "BQ-query": "Retrieve semantically similar text", - "BQ-passage": "Retrieve semantically similar text", + "BQ-document": "Retrieve semantically similar text", "LCQMC-query": "Retrieve semantically similar text", - "LCQMC-passage": "Retrieve semantically similar text", + "LCQMC-document": "Retrieve semantically similar text", "PAWSX-query": "Retrieve semantically similar text", - "PAWSX-passage": "Retrieve semantically similar text", + "PAWSX-document": "Retrieve semantically similar text", "QBQTC-query": "Retrieve semantically similar text", - "QBQTC-passage": "Retrieve semantically similar text", + "QBQTC-document": "Retrieve semantically similar text", "STSB-query": "Retrieve semantically similar text", - "STSB-passage": "Retrieve semantically similar text", + "STSB-document": "Retrieve semantically similar text", "BIOSSES-query": "Retrieve semantically similar text", - 
"BIOSSES-passage": "Retrieve semantically similar text", + "BIOSSES-document": "Retrieve semantically similar text", "SICK-R-query": "Retrieve semantically similar text", - "SICK-R-passage": "Retrieve semantically similar text", + "SICK-R-document": "Retrieve semantically similar text", "STS12-query": "Retrieve semantically similar text", - "STS12-passage": "Retrieve semantically similar text", + "STS12-document": "Retrieve semantically similar text", "STS13-query": "Retrieve semantically similar text", - "STS13-passage": "Retrieve semantically similar text", + "STS13-document": "Retrieve semantically similar text", "STS14-query": "Retrieve semantically similar text", - "STS14-passage": "Retrieve semantically similar text", + "STS14-document": "Retrieve semantically similar text", "STS15-query": "Retrieve semantically similar text", - "STS15-passage": "Retrieve semantically similar text", + "STS15-document": "Retrieve semantically similar text", "STS16-query": "Retrieve semantically similar text", - "STS16-passage": "Retrieve semantically similar text", + "STS16-document": "Retrieve semantically similar text", "STS17-query": "Retrieve semantically similar text", - "STS17-passage": "Retrieve semantically similar text", + "STS17-document": "Retrieve semantically similar text", "STS22-query": "Retrieve semantically similar text", - "STS22-passage": "Retrieve semantically similar text", + "STS22-document": "Retrieve semantically similar text", "STSBenchmark-query": "Retrieve semantically similar text", - "STSBenchmark-passage": "Retrieve semantically similar text", + "STSBenchmark-document": "Retrieve semantically similar text", "SummEval-query": "Retrieve semantically similar summaries", - "SummEval-passage": "Retrieve semantically similar summaries", + "SummEval-document": "Retrieve semantically similar summaries", } KaLM_X_task_prompts = { @@ -530,7 +534,7 @@ def encode( "CMedQAv1-reranking-query": "Given a Chinese community medical question, retrieve replies that 
best answer the question", "CMedQAv2-reranking-query": "Given a Chinese community medical question, retrieve replies that best answer the question", "ArguAna-query": "Given a claim, find documents that refute the claim", - "ArguAna-passage": "Given a claim, find documents that refute the claim", + "ArguAna-document": "Given a claim, find documents that refute the claim", "ClimateFEVER-query": "Given a claim about climate change, retrieve documents that support or refute the claim", "ClimateFEVERHardNegatives-query": "Given a claim about climate change, retrieve documents that support or refute the claim", "DBPedia-query": "Given a query, retrieve relevant entity descriptions from DBPedia", @@ -677,9 +681,9 @@ def encode( "MIRACLRetrievalHardNegatives-query": "Retrieval relevant passage for the given query", "CQADupstackRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackGamingRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackGamingRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackGamingRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", "CQADupstackUnixRetrieval-query": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", - "CQADupstackUnixRetrieval-passage": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", + "CQADupstackUnixRetrieval-document": "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question", } KaLM_INSTRUCTION = "Instruct: {instruction} \n Query: " diff --git 
a/mteb/models/mcinext_models.py b/mteb/models/mcinext_models.py index a5f02279ca..ce70a4bbb9 100644 --- a/mteb/models/mcinext_models.py +++ b/mteb/models/mcinext_models.py @@ -256,7 +256,7 @@ def _preprocess_sample( if task_id == 3: if sub == "sentence1" or (prompt_type and prompt_type.value == "query"): return f"{task_prompt} | متن اول : {sample}" - if sub == "sentence2" or (prompt_type and prompt_type.value == "passage"): + if sub == "sentence2" or (prompt_type and prompt_type.value == "document"): return f"{task_prompt} | متن دوم : {sample}" return sample diff --git a/mteb/models/nomic_models.py b/mteb/models/nomic_models.py index f02157aa09..78d7163ef8 100644 --- a/mteb/models/nomic_models.py +++ b/mteb/models/nomic_models.py @@ -56,7 +56,7 @@ def encode( # type: ignore # default to search_document if input_type and prompt_name are not provided prompt_name = ( self.get_prompt_name(self.model_prompts, task_name, prompt_type) - or PromptType.passage.value + or PromptType.document.value ) task = mteb.get_task(task_name) # normalization not applied to classification @@ -177,7 +177,7 @@ def encode( # type: ignore "STS": "classification: ", "Summarization": "classification: ", PromptType.query.value: "search_query: ", - PromptType.passage.value: "search_document: ", + PromptType.document.value: "search_document: ", } nomic_embed_v1_5 = ModelMeta( diff --git a/mteb/models/qwen3_models.py b/mteb/models/qwen3_models.py index 3571cc99c1..33862939b4 100644 --- a/mteb/models/qwen3_models.py +++ b/mteb/models/qwen3_models.py @@ -11,7 +11,7 @@ def instruction_template( instruction: str, prompt_type: PromptType | None = None ) -> str: - if not instruction or prompt_type == PromptType.passage: + if not instruction or prompt_type == PromptType.document: return "" if isinstance(instruction, dict): if prompt_type is None: diff --git a/mteb/models/repllama_models.py b/mteb/models/repllama_models.py index eed13cd71b..e704cb865a 100644 --- a/mteb/models/repllama_models.py +++ 
b/mteb/models/repllama_models.py @@ -121,7 +121,7 @@ def loader_inner(**kwargs: Any) -> Encoder: model_prompts = { PromptType.query.value: "query: ", - PromptType.passage.value: "passage: ", + PromptType.document.value: "passage: ", } repllama_llama2_original = ModelMeta( diff --git a/mteb/models/ru_sentence_models.py b/mteb/models/ru_sentence_models.py index 2a9c7063a4..9c3be1a76a 100644 --- a/mteb/models/ru_sentence_models.py +++ b/mteb/models/ru_sentence_models.py @@ -35,23 +35,23 @@ "SensitiveTopicsClassification": "Классифицируй чувствительную тему по запросу \nзапрос: ", "RuBQRetrieval": { "query": "Given a question, retrieve Wikipedia passages that answer the question\nquery: ", - "passage": "", + "document": "", }, "RuBQReranking": { "query": "Given a question, retrieve Wikipedia passages that answer the question\nquery: ", - "passage": "", + "document": "", }, "RiaNewsRetrieval": { "query": "Given a news title, retrieve relevant news article\nquery: ", - "passage": "", + "document": "", }, "MIRACLReranking": { "query": "Given a question, retrieve Wikipedia passages that answer the question\nquery: ", - "passage": "", + "document": "", }, "MIRACLRetrieval": { "query": "Given a question, retrieve Wikipedia passages that answer the question\nquery: ", - "passage": "", + "document": "", }, } @@ -156,7 +156,7 @@ sentence_transformers_loader, model_name="deepvk/USER-base", revision="436a489a2087d61aa670b3496a9915f84e46c861", - model_prompts={"query": "query: ", "passage": "passage: "}, + model_prompts={"query": "query: ", "document": "passage: "}, ), name="deepvk/USER-base", languages=["rus-Cyrl"], @@ -431,11 +431,11 @@ "PairClassification": "classification: ", "Reranking": "classification: ", f"Reranking-{PromptType.query.value}": "search_query: ", - f"Reranking-{PromptType.passage.value}": "search_document: ", + f"Reranking-{PromptType.document.value}": "search_document: ", "STS": "classification: ", "Summarization": "classification: ", PromptType.query.value: 
"search_query: ", - PromptType.passage.value: "search_document: ", + PromptType.document.value: "search_document: ", # Override some prompts for ruMTEB tasks "HeadlineClassification": "clustering: ", "InappropriatenessClassification": "clustering: ", @@ -497,11 +497,11 @@ "PairClassification": "paraphrase: ", "Reranking": "paraphrase: ", f"Reranking-{PromptType.query.value}": "search_query: ", - f"Reranking-{PromptType.passage.value}": "search_document: ", + f"Reranking-{PromptType.document.value}": "search_document: ", "STS": "paraphrase: ", "Summarization": "categorize: ", PromptType.query.value: "search_query: ", - PromptType.passage.value: "search_document: ", + PromptType.document.value: "search_document: ", # Override some prompts for ruMTEB tasks "CEDRClassification": "categorize_sentiment: ", "GeoreviewClassification": "categorize_sentiment: ", @@ -751,11 +751,11 @@ "PairClassification": "classification: ", "Reranking": "classification: ", f"Reranking-{PromptType.query.value}": "search_query: ", - f"Reranking-{PromptType.passage.value}": "search_document: ", + f"Reranking-{PromptType.document.value}": "search_document: ", "STS": "classification: ", "Summarization": "clustering: ", PromptType.query.value: "search_query: ", - PromptType.passage.value: "search_document: ", + PromptType.document.value: "search_document: ", } user2_small = ModelMeta( loader=partial( diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index 872a3fffe9..e9d5492803 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -45,6 +45,13 @@ def __init__( ): try: model_prompts = self.validate_task_to_prompt_name(self.model.prompts) + + if ( + len(self.model.prompts) == 2 + and self.model.prompts.get("query", "") == "" + and self.model.prompts.get("document", "") == "" + ): + model_prompts = None except KeyError: model_prompts = None logger.warning( @@ -89,24 +96,24 @@ def encode( 
Returns: The encoded sentences. """ + prompt = None prompt_name = None if self.model_prompts is not None: prompt_name = self.get_prompt_name( self.model_prompts, task_name, prompt_type ) + prompt = self.model_prompts.get(prompt_name, None) if prompt_name: logger.info( - f"Using prompt_name={prompt_name} for task={task_name} prompt_type={prompt_type}" + f"Using {prompt_name=} for task={task_name} {prompt_type=} with {prompt=}" ) else: - logger.info( - f"No model prompts found for task={task_name} prompt_type={prompt_type}" - ) + logger.info(f"No model prompts found for task={task_name} {prompt_type=}") logger.info(f"Encoding {len(sentences)} sentences.") embeddings = self.model.encode( sentences, - prompt_name=prompt_name, + prompt=prompt, **kwargs, ) if isinstance(embeddings, torch.Tensor): diff --git a/mteb/models/vlm2vec_models.py b/mteb/models/vlm2vec_models.py index 65ca7b4004..38823d77a7 100644 --- a/mteb/models/vlm2vec_models.py +++ b/mteb/models/vlm2vec_models.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -EncodeTypes = Literal["query", "passage"] +EncodeTypes = Literal["query", "document"] class VLM2VecWrapper: diff --git a/mteb/models/voyage_models.py b/mteb/models/voyage_models.py index b1eb33442a..c8c58baef4 100644 --- a/mteb/models/voyage_models.py +++ b/mteb/models/voyage_models.py @@ -140,7 +140,7 @@ def _batched_encode( model_prompts = { PromptType.query.value: "query", - PromptType.passage.value: "document", + PromptType.document.value: "document", } voyage_large_2_instruct = ModelMeta( diff --git a/mteb/models/voyage_v.py b/mteb/models/voyage_v.py index 48a083574c..fada131681 100644 --- a/mteb/models/voyage_v.py +++ b/mteb/models/voyage_v.py @@ -87,7 +87,7 @@ def get_text_embeddings( **kwargs: Any, ): if input_type is None and prompt_type is not None: - if prompt_type == PromptType.passage: + if prompt_type == PromptType.document: input_type = "document" elif prompt_type == PromptType.query: input_type = "query" @@ -119,7 +119,7 @@ def 
get_image_embeddings( **kwargs: Any, ): if input_type is None and prompt_type is not None: - if prompt_type == PromptType.passage: + if prompt_type == PromptType.document: input_type = "document" elif prompt_type == PromptType.query: input_type = "query" @@ -175,7 +175,7 @@ def get_fused_embeddings( raise ValueError("Either texts or images must be provided") if input_type is None and prompt_type is not None: - if prompt_type == PromptType.passage: + if prompt_type == PromptType.document: input_type = "document" elif prompt_type == PromptType.query: input_type = "query" diff --git a/mteb/models/wrapper.py b/mteb/models/wrapper.py index 73ebdde341..68ec09ae62 100644 --- a/mteb/models/wrapper.py +++ b/mteb/models/wrapper.py @@ -35,7 +35,7 @@ def get_prompt_name( Args: task_to_prompt: The tasks names and their corresponding prompt_names task_name: The task name to use for building the encoding prompt - prompt_type: The prompt type (e.g. "query" | "passage") to use for building the encoding prompt + prompt_type: The prompt type (e.g. "query" | "document") to use for building the encoding prompt """ task = mteb.get_task(task_name=task_name) task_type = task.metadata.type @@ -74,7 +74,7 @@ def validate_task_to_prompt_name( prompt_types = [e.value for e in PromptType] for task_name in task_to_prompt_name: if "-" in task_name and task_name.endswith( - (f"-{PromptType.query.value}", f"-{PromptType.passage.value}") + (f"-{PromptType.query.value}", f"-{PromptType.document.value}") ): task_name, prompt_type = task_name.rsplit("-", 1) if prompt_type not in prompt_types: diff --git a/tests/test_benchmark/mock_models.py b/tests/test_benchmark/mock_models.py index 8fd45fd6bd..4feb6820a0 100644 --- a/tests/test_benchmark/mock_models.py +++ b/tests/test_benchmark/mock_models.py @@ -171,7 +171,7 @@ def __init__( model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model. revision: The revision of the model to use. 
model_prompts: A dictionary mapping task names to prompt names. - First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt, + First priority is given to the composed prompt of task name + prompt type (query or document), then to the specific task prompt, then to the composed prompt of task type + prompt type, then to the specific task type prompt, and finally to the specific prompt type. **kwargs: Additional arguments to pass to the SentenceTransformer model. diff --git a/tests/test_benchmark/test_benchmark.py b/tests/test_benchmark/test_benchmark.py index 11b5f4cc7f..69487b2805 100644 --- a/tests/test_benchmark/test_benchmark.py +++ b/tests/test_benchmark/test_benchmark.py @@ -115,7 +115,7 @@ def encode(self, sentences, prompt_name: str | None = None, **kwargs): class EncoderWithoutInstructions(MockSentenceTransformer): def encode(self, sentences, **kwargs): - assert kwargs["prompt_name"] is None + assert kwargs["prompt"] is None return super().encode(sentences, **kwargs) if isinstance(task_name, mteb.AbsTask): @@ -302,7 +302,7 @@ def encode(self, sentences, prompt_name: str | None = None, **kwargs): ) -@pytest.mark.parametrize("task_name", ["NQ-NL-query", "NQ-NL-passage"]) +@pytest.mark.parametrize("task_name", ["NQ-NL-query", "NQ-NL-document"]) def test_prompt_name_split_correctly(task_name: str, tmp_path: Path): """Test that the task name is split correctly into task name and prompt type for tasks with multiple `-` in their names. 
@@ -331,12 +331,12 @@ def test_model_query_passage_prompts_task_type( task_name = task.metadata.name if is_task_name else task.metadata.type def check_prompt(prompt_name, is_query): - prompt_type = "query" if is_query else "passage" + prompt_type = "query" if is_query else "document" assert prompt_name == f"{task_name}-{prompt_type}" prompt_list = { f"{task_name}-query": "query", - f"{task_name}-passage": "passage", + f"{task_name}-document": "document", } class MockEncoderWithPrompts(mteb.Encoder): diff --git a/tests/test_reproducible_workflow.py b/tests/test_reproducible_workflow.py index d584a98852..738392c623 100644 --- a/tests/test_reproducible_workflow.py +++ b/tests/test_reproducible_workflow.py @@ -57,10 +57,12 @@ def test_validate_task_to_prompt_name(task_name: str | mteb.AbsTask): model_prompts = dict.fromkeys(task_names, "prompt_name") model_prompts |= {task_name + "-query": "prompt_name" for task_name in task_names} - model_prompts |= {task_name + "-passage": "prompt_name" for task_name in task_names} + model_prompts |= { + task_name + "-document": "prompt_name" for task_name in task_names + } model_prompts |= { "query": "prompt_name", - "passage": "prompt_name", + "document": "prompt_name", } Wrapper.validate_task_to_prompt_name(model_prompts)