diff --git a/mteb/models/bedrock_models.py b/mteb/models/bedrock_models.py
index a97535960c..46e3f02113 100644
--- a/mteb/models/bedrock_models.py
+++ b/mteb/models/bedrock_models.py
@@ -39,11 +39,7 @@ def __init__(
         self._provider = provider.lower()
 
         if self._provider == "cohere":
-            self.model_prompts = (
-                self.validate_task_to_prompt_name(model_prompts)
-                if model_prompts
-                else None
-            )
+            self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
             self._max_batch_size = 96
             self._max_sequence_length = max_tokens * 4
         else:
diff --git a/mteb/models/cohere_models.py b/mteb/models/cohere_models.py
index 606195417a..dbbfd35dfa 100644
--- a/mteb/models/cohere_models.py
+++ b/mteb/models/cohere_models.py
@@ -135,9 +135,7 @@ def __init__(
     ) -> None:
         self.model_name = model_name
         self.sep = sep
-        self.model_prompts = (
-            self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
-        )
+        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
     def _embed(
         self,
diff --git a/mteb/models/colbert_models.py b/mteb/models/colbert_models.py
index 82c0e1d1e2..c0e22ea306 100644
--- a/mteb/models/colbert_models.py
+++ b/mteb/models/colbert_models.py
@@ -40,17 +40,11 @@ def __init__(
         self.model_name = model_name
         self.model = colbert_model.ColBERT(self.model_name, revision=revision, **kwargs)
 
-        if (
-            model_prompts is None
-            and hasattr(self.model, "prompts")
-            and len(self.model.prompts) > 0
-        ):
-            try:
-                model_prompts = self.validate_task_to_prompt_name(self.model.prompts)
-            except ValueError:
-                model_prompts = None
-        elif model_prompts is not None and hasattr(self.model, "prompts"):
-            logger.info(f"Model prompts will be overwritten with {model_prompts}")
+        built_in_prompts = getattr(self.model, "prompts", None)
+        if built_in_prompts and not model_prompts:
+            model_prompts = built_in_prompts
+        elif model_prompts and built_in_prompts:
+            logger.info(f"Model.prompts will be overwritten with {model_prompts}")
 
         self.model.prompts = model_prompts
         self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
diff --git a/mteb/models/google_models.py b/mteb/models/google_models.py
index 2ef93b261b..9636d1ded2 100644
--- a/mteb/models/google_models.py
+++ b/mteb/models/google_models.py
@@ -60,9 +60,7 @@ def __init__(
         **kwargs,
     ) -> None:
         self.model_name = model_name
-        self.model_prompts = (
-            self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
-        )
+        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
     def _embed(
         self,
diff --git a/mteb/models/llm2vec_models.py b/mteb/models/llm2vec_models.py
index 37983bc159..b73678a681 100644
--- a/mteb/models/llm2vec_models.py
+++ b/mteb/models/llm2vec_models.py
@@ -72,9 +72,7 @@ def __init__(
 
             extra_kwargs["attn_implementation"] = "flash_attention_2"
 
-        self.model_prompts = (
-            self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
-        )
+        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
         if device:
             kwargs["device_map"] = device
diff --git a/mteb/models/repllama_models.py b/mteb/models/repllama_models.py
index e704cb865a..2a135712e2 100644
--- a/mteb/models/repllama_models.py
+++ b/mteb/models/repllama_models.py
@@ -47,9 +47,7 @@ def __init__(
         # set the max_length for the evals as they did, although the model can handle longer
         self.model.config.max_length = 512
         self.tokenizer.model_max_length = 512
-        self.model_prompts = (
-            self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
-        )
+        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
     def create_batch_dict(self, tokenizer, input_texts):
         max_length = self.model.config.max_length
diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py
index e9d5492803..d27584b96c 100644
--- a/mteb/models/sentence_transformer_wrapper.py
+++ b/mteb/models/sentence_transformer_wrapper.py
@@ -39,28 +39,35 @@ def __init__(
         self.model = model
 
         if (
-            model_prompts is None
-            and hasattr(self.model, "prompts")
-            and len(self.model.prompts) > 0
-        ):
-            try:
-                model_prompts = self.validate_task_to_prompt_name(self.model.prompts)
-
-                if (
-                    len(self.model.prompts) == 2
-                    and self.model.prompts.get("query", "") == ""
-                    and self.model.prompts.get("document", "") == ""
-                ):
-                    model_prompts = None
-            except KeyError:
-                model_prompts = None
-                logger.warning(
-                    "Model prompts are not in the expected format. Ignoring them."
-                )
-        elif model_prompts is not None and hasattr(self.model, "prompts"):
-            logger.info(f"Model prompts will be overwritten with {model_prompts}")
+            built_in_prompts := getattr(self.model, "prompts", None)
+        ) and not model_prompts:
+            model_prompts = built_in_prompts
+        elif model_prompts and built_in_prompts:
+            logger.warning(f"Model prompts will be overwritten with {model_prompts}")
 
         self.model.prompts = model_prompts
-        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
+
+        self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
+            model_prompts, raise_for_invalid_keys=False
+        )
+
+        if invalid_prompts:
+            invalid_prompts = "\n".join(invalid_prompts)
+            logger.warning(
+                f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
+            )
+
+        if (
+            self.model_prompts
+            and len(self.model_prompts) <= 2
+            and (
+                PromptType.query.value not in self.model_prompts
+                or PromptType.document.value not in self.model_prompts
+            )
+        ):
+            logger.warning(
+                "SentenceTransformers that use prompts most often need to be configured with at least 'query' and"
+                f" 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
+            )
 
         if isinstance(self.model, CrossEncoder):
             self.predict = self._predict
diff --git a/mteb/models/voyage_models.py b/mteb/models/voyage_models.py
index 15c934d573..2d3b3b100f 100644
--- a/mteb/models/voyage_models.py
+++ b/mteb/models/voyage_models.py
@@ -85,9 +85,7 @@ def __init__(
         self._embed_func = rate_limit(max_rpm)(token_limit(max_tpm)(self._client.embed))
         self._model_name = model_name
         self._max_tpm = max_tpm
-        self.model_prompts = (
-            self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
-        )
+        self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
     def encode(
         self,
diff --git a/mteb/models/wrapper.py b/mteb/models/wrapper.py
index 68ec09ae62..ccbdc59713 100644
--- a/mteb/models/wrapper.py
+++ b/mteb/models/wrapper.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import logging
-from typing import Callable, get_args
+from collections.abc import Callable, Sequence
+from typing import Literal, get_args, overload
 
 import mteb
 from mteb.abstasks.TaskMetadata import TASK_TYPE
@@ -65,29 +66,95 @@ def get_prompt_name(
         return None
 
     @staticmethod
+    @overload
     def validate_task_to_prompt_name(
-        task_to_prompt_name: dict[str, str] | None,
-    ) -> dict[str, str] | None:
-        if task_to_prompt_name is None:
-            return task_to_prompt_name
+        task_to_prompt: dict[str, str] | None,
+        raise_for_invalid_keys: Literal[True] = True,
+    ) -> dict[str, str] | None: ...
+
+    @staticmethod
+    @overload
+    def validate_task_to_prompt_name(
+        task_to_prompt: dict[str, str] | None,
+        raise_for_invalid_keys: Literal[False] = False,
+    ) -> tuple[dict[str, str], Sequence[str]] | tuple[None, None]: ...
+
+    @staticmethod
+    def validate_task_to_prompt_name(
+        task_to_prompt: dict[str, str] | None,
+        raise_for_invalid_keys: bool = True,
+    ) -> (
+        dict[str, str] | tuple[dict[str, str], Sequence[str]] | tuple[None, None] | None
+    ):
+        """Validates that the keys in task_to_prompt map to a known task or prompt type.
+
+        A key is valid if:
+
+        1. It is a valid task name; or
+        2. It is a valid task type; or
+        3. It is a valid prompt type; or
+        4. It is a compound key of the form "{task_name}-{prompt_type}" where task_name is a valid task type or task
+           name and prompt_type is a valid prompt type.
+
+        See the
+        [MTEB docs](https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts)
+        for a complete description of the order of precedence for these keys when running an evaluation.
+
+        Arguments:
+            task_to_prompt: The dictionary of prompts.
+            raise_for_invalid_keys: If True, raise an error when an invalid key is encountered, otherwise return the
+                list of error messages along with a filtered dictionary of prompts with valid keys. Defaults to True
+                for backward compatibility.
+
+        Returns:
+            * None if `task_to_prompt` is None or empty;
+            * Only a dictionary of validated prompts if `raise_for_invalid_keys` is `True`; or
+            * A tuple containing the filtered dictionary of valid prompts and the set of error messages for the
+              invalid prompts if `raise_for_invalid_keys` is `False`.
+
+        Raises:
+            KeyError: If any invalid keys are encountered and `raise_for_invalid_keys` is `True`, this function will
+                raise a single `KeyError` containing the error messages for all invalid keys.
+        """
+        if not task_to_prompt:
+            return None if raise_for_invalid_keys else (None, None)
+
         task_types = get_args(TASK_TYPE)
         prompt_types = [e.value for e in PromptType]
-        for task_name in task_to_prompt_name:
-            if "-" in task_name and task_name.endswith(
-                (f"-{PromptType.query.value}", f"-{PromptType.document.value}")
-            ):
-                task_name, prompt_type = task_name.rsplit("-", 1)
-                if prompt_type not in prompt_types:
-                    msg = f"Prompt type {prompt_type} is not valid. Valid prompt types are {prompt_types}"
-                    logger.warning(msg)
-                    raise KeyError(msg)
+        valid_keys_msg = f"Valid keys are task types [{task_types}], prompt types [{prompt_types}], and task names"
+        valid_prompt_type_endings = tuple(
+            [f"-{prompt_type}" for prompt_type in prompt_types]
+        )
+
+        invalid_keys: set[str] = set()
+        invalid_task_messages: set[str] = set()
+
+        for task_key in task_to_prompt:
+            # task_key may be a compound key of the form "{task_name}-{prompt_type}". A task_name may contain a "-"
+            # character (this occurs in ~12% of task names), so rsplit is used to separate a valid prompt_type postfix
+            # from the unvalidated task_name.
+            if task_key.endswith(valid_prompt_type_endings):
+                task_name = task_key.rsplit("-", 1)[0]
+            else:
+                task_name = task_key
+
             if task_name not in task_types and task_name not in prompt_types:
-                task = mteb.get_task(task_name=task_name)
-                if not task:
-                    msg = f"Task name {task_name} is not valid. Valid task names are task types [{task_types}], prompt types [{prompt_types}] and task names"
+                try:
+                    mteb.get_task(task_name=task_name)
+                except KeyError:
+                    msg = f"Task name {task_name} is not valid. {valid_keys_msg}"
                     logger.warning(msg)
-                    raise KeyError(msg)
-        return task_to_prompt_name
+                    invalid_task_messages.add(msg)
+                    invalid_keys.add(task_key)
+
+        if raise_for_invalid_keys and invalid_task_messages:
+            raise KeyError(invalid_task_messages)
+        elif raise_for_invalid_keys:
+            return task_to_prompt
+        else:
+            return {
+                k: v for k, v in task_to_prompt.items() if k not in invalid_keys
+            }, tuple(invalid_task_messages)
 
     @staticmethod
     def get_instruction(
diff --git a/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py b/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py
index 4ca0056cd7..be5377ea73 100644
--- a/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py
+++ b/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py
@@ -24,7 +24,13 @@
 )
 def test_benchmark_sentence_transformer(task: str | AbsTask, model_name: str):
     """Test that a task can be fetched and run"""
-    if isinstance(model_name, str):
-        model = SentenceTransformer(model_name)
+    model = SentenceTransformer(model_name)
+    # Prior to https://github.com/embeddings-benchmark/mteb/pull/3079 the
+    # SentenceTransformerWrapper would set the model's prompts to None because
+    # the mock tasks are not in the MTEB task registry. The linked PR changes
+    # this behavior and keeps the prompts as configured by the model, so this
+    # test now sets the prompts to None explicitly to preserve the legacy
+    # behavior and focus the test on the tasks instead of the prompts.
+    model.prompts = None
     eval = MTEB(tasks=[task])
     eval.run(model, output_folder="tests/results", overwrite_results=True)
diff --git a/tests/test_reproducible_workflow.py b/tests/test_reproducible_workflow.py
index 738392c623..48f64e6496 100644
--- a/tests/test_reproducible_workflow.py
+++ b/tests/test_reproducible_workflow.py
@@ -14,9 +14,16 @@
 logging.basicConfig(level=logging.INFO)
 
 
-@pytest.mark.parametrize("task_name", ["BornholmBitextMining"])
-@pytest.mark.parametrize("model_name", ["sentence-transformers/all-MiniLM-L6-v2"])
-@pytest.mark.parametrize("model_revision", ["8b3219a92973c328a8e22fadcfa821b5dc75636a"])
+@pytest.mark.parametrize(
+    "task_name, model_name, model_revision",
+    [
+        (
+            "BornholmBitextMining",
+            "sentence-transformers/all-MiniLM-L6-v2",
+            "8b3219a92973c328a8e22fadcfa821b5dc75636a",
+        ),
+    ],
+)
 def test_reproducibility_workflow(task_name: str, model_name: str, model_revision: str):
     """Test that a model and a task can be fetched and run in a reproducible fashion."""
     model_meta = mteb.get_model_meta(model_name, revision=model_revision)
@@ -67,11 +74,51 @@ def test_validate_task_to_prompt_name(task_name: str | mteb.AbsTask):
         Wrapper.validate_task_to_prompt_name(model_prompts)
 
 
-def test_validate_task_to_prompt_name_fail():
-    with pytest.raises(KeyError):
-        Wrapper.validate_task_to_prompt_name(
-            {"task_name": "prompt_name", "task_name-query": "prompt_name"}
-        )
+@pytest.mark.parametrize("raise_for_invalid_keys", (True, False))
+def test_validate_task_to_prompt_name_for_none(raise_for_invalid_keys: bool):
+    result = Wrapper.validate_task_to_prompt_name(
+        None, raise_for_invalid_keys=raise_for_invalid_keys
+    )
+    assert result == (None if raise_for_invalid_keys else (None, None))
+
 
+@pytest.mark.parametrize(
+    "task_prompt_dict",
+    [
+        {"task_name": "prompt_name"},
+        {"task_name-query": "prompt_name"},
+        {"task_name-task_name": "prompt_name"},
+    ],
+)
+def test_validate_task_to_prompt_name_fails_and_raises(
+    task_prompt_dict: dict[str, str],
+):
     with pytest.raises(KeyError):
-        Wrapper.validate_task_to_prompt_name({"task_name-task_name": "prompt_name"})
+        Wrapper.validate_task_to_prompt_name(task_prompt_dict)
+
+
+@pytest.mark.parametrize(
+    "task_prompt_dict, expected_valid, expected_invalid",
+    [
+        ({"task_name": "prompt_name"}, 0, 1),
+        ({"task_name-query": "prompt_name"}, 0, 1),
+        (
+            {
+                "task_name-query": "prompt_name",
+                "query": "prompt_name",
+                "Retrieval": "prompt_name",
+            },
+            2,
+            1,
+        ),
+        ({"task_name-task_name": "prompt_name"}, 0, 1),
+    ],
+)
+def test_validate_task_to_prompt_name_filters_and_reports(
+    task_prompt_dict: dict[str, str], expected_valid: int, expected_invalid: int
+):
+    valid, invalid = Wrapper.validate_task_to_prompt_name(
+        task_prompt_dict, raise_for_invalid_keys=False
+    )
+    assert len(valid) == expected_valid
+    assert len(invalid) == expected_invalid
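
Usage sketch (not part of the patch): the snippet below illustrates the two calling modes of `Wrapper.validate_task_to_prompt_name` after this change. The prompt strings and the `NotARealTask-query` key are invented for illustration, and the import path assumes `Wrapper` is importable from `mteb.models.wrapper` as shown in the diff above.

```python
from mteb.models.wrapper import Wrapper  # assumed import path; see mteb/models/wrapper.py above

prompts = {
    "query": "Represent the query for retrieval:",  # valid: a prompt type
    "Retrieval": "Represent the passage:",          # valid: a task type
    "NotARealTask-query": "ignored",                # invalid: unknown task name
}

# Default mode keeps the old behavior: any invalid key raises KeyError.
try:
    Wrapper.validate_task_to_prompt_name(prompts)
except KeyError as err:
    print(f"rejected: {err}")

# New mode filters out invalid keys and reports them instead of raising,
# which is what SentenceTransformerWrapper now relies on.
valid, problems = Wrapper.validate_task_to_prompt_name(
    prompts, raise_for_invalid_keys=False
)
print(sorted(valid))  # ['Retrieval', 'query']
print(problems)       # ("Task name NotARealTask is not valid. ...",)
```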