6 changes: 1 addition & 5 deletions mteb/models/bedrock_models.py
@@ -39,11 +39,7 @@ def __init__(
self._provider = provider.lower()

if self._provider == "cohere":
self.model_prompts = (
self.validate_task_to_prompt_name(model_prompts)
if model_prompts
else None
)
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
self._max_batch_size = 96
self._max_sequence_length = max_tokens * 4
else:
4 changes: 1 addition & 3 deletions mteb/models/cohere_models.py
@@ -135,9 +135,7 @@ def __init__(
) -> None:
self.model_name = model_name
self.sep = sep
self.model_prompts = (
self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
)
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

def _embed(
self,
16 changes: 5 additions & 11 deletions mteb/models/colbert_models.py
@@ -40,17 +40,11 @@ def __init__(

self.model_name = model_name
self.model = colbert_model.ColBERT(self.model_name, revision=revision, **kwargs)
if (
model_prompts is None
and hasattr(self.model, "prompts")
and len(self.model.prompts) > 0
):
try:
model_prompts = self.validate_task_to_prompt_name(self.model.prompts)
except ValueError:
model_prompts = None
elif model_prompts is not None and hasattr(self.model, "prompts"):
logger.info(f"Model prompts will be overwritten with {model_prompts}")
built_in_prompts = getattr(self.model, "prompts", None)
if built_in_prompts and not model_prompts:
model_prompts = built_in_prompts
elif model_prompts and built_in_prompts:
logger.info(f"Model.prompts will be overwritten with {model_prompts}")
self.model.prompts = model_prompts
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

4 changes: 1 addition & 3 deletions mteb/models/google_models.py
@@ -60,9 +60,7 @@ def __init__(
**kwargs,
) -> None:
self.model_name = model_name
self.model_prompts = (
self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
)
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

def _embed(
self,
4 changes: 1 addition & 3 deletions mteb/models/llm2vec_models.py
@@ -72,9 +72,7 @@ def __init__(

extra_kwargs["attn_implementation"] = "flash_attention_2"

self.model_prompts = (
self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
)
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

if device:
kwargs["device_map"] = device
4 changes: 1 addition & 3 deletions mteb/models/repllama_models.py
@@ -47,9 +47,7 @@ def __init__(
# set the max_length for the evals as they did, although the model can handle longer
self.model.config.max_length = 512
self.tokenizer.model_max_length = 512
self.model_prompts = (
self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
)
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

def create_batch_dict(self, tokenizer, input_texts):
max_length = self.model.config.max_length
49 changes: 28 additions & 21 deletions mteb/models/sentence_transformer_wrapper.py
@@ -39,28 +39,35 @@ def __init__(
self.model = model

if (
model_prompts is None
and hasattr(self.model, "prompts")
and len(self.model.prompts) > 0
):
try:
model_prompts = self.validate_task_to_prompt_name(self.model.prompts)

if (
len(self.model.prompts) == 2
and self.model.prompts.get("query", "") == ""
and self.model.prompts.get("document", "") == ""
):
model_prompts = None
except KeyError:
model_prompts = None
logger.warning(
"Model prompts are not in the expected format. Ignoring them."
)
elif model_prompts is not None and hasattr(self.model, "prompts"):
logger.info(f"Model prompts will be overwritten with {model_prompts}")
built_in_prompts := getattr(self.model, "prompts", None)
) and not model_prompts:
model_prompts = built_in_prompts
elif model_prompts and built_in_prompts:
logger.warning(f"Model prompts will be overwritten with {model_prompts}")
self.model.prompts = model_prompts
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
model_prompts, raise_for_invalid_keys=False
)

if invalid_prompts:
invalid_prompts = "\n".join(invalid_prompts)
logger.warning(
f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
)

if (
self.model_prompts
and len(self.model_prompts) <= 2
and (
PromptType.query.value not in self.model_prompts
or PromptType.document.value not in self.model_prompts
)
):
logger.warning(
"SentenceTransformers that use prompts most often need to be configured with at least 'query' and"
f" 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
)

if isinstance(self.model, CrossEncoder):
self.predict = self._predict
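Taken together, the colbert_models.py and sentence_transformer_wrapper.py diffs above replace the old "validate the model's built-in prompts and silently drop them on failure" branches with a simple precedence rule: explicit `model_prompts` win, the model's built-in prompts are used only as a fallback, and validation happens once afterwards. Below is a minimal standalone sketch of that rule; the `resolve_prompts` helper and its name are illustrative only and not part of this PR.

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def resolve_prompts(model, model_prompts: dict[str, str] | None) -> dict[str, str] | None:
    # Prompts shipped with the model, if any (e.g. SentenceTransformer.prompts).
    built_in_prompts = getattr(model, "prompts", None)
    if built_in_prompts and not model_prompts:
        # No user-supplied prompts: fall back to the model's own prompts.
        model_prompts = built_in_prompts
    elif model_prompts and built_in_prompts:
        # User-supplied prompts take precedence over the built-in ones.
        logger.warning(f"Model prompts will be overwritten with {model_prompts}")
    return model_prompts

In the actual wrappers the resolved prompts are then written back to `self.model.prompts` and passed through `validate_task_to_prompt_name`, as shown above.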
4 changes: 1 addition & 3 deletions mteb/models/voyage_models.py
@@ -85,9 +85,7 @@ def __init__(
self._embed_func = rate_limit(max_rpm)(token_limit(max_tpm)(self._client.embed))
self._model_name = model_name
self._max_tpm = max_tpm
self.model_prompts = (
self.validate_task_to_prompt_name(model_prompts) if model_prompts else None
)
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)

def encode(
self,
105 changes: 86 additions & 19 deletions mteb/models/wrapper.py
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
from typing import Callable, get_args
from collections.abc import Callable, Sequence
from typing import Literal, get_args, overload

import mteb
from mteb.abstasks.TaskMetadata import TASK_TYPE
@@ -65,29 +66,95 @@ def get_prompt_name(
return None

@staticmethod
@overload
def validate_task_to_prompt_name(
task_to_prompt_name: dict[str, str] | None,
) -> dict[str, str] | None:
if task_to_prompt_name is None:
return task_to_prompt_name
task_to_prompt: dict[str, str] | None,
raise_for_invalid_keys: Literal[True] = True,
) -> dict[str, str] | None: ...

@staticmethod
@overload
def validate_task_to_prompt_name(
task_to_prompt: dict[str, str] | None,
raise_for_invalid_keys: Literal[False] = False,
) -> tuple[dict[str, str], Sequence[str]] | tuple[None, None]: ...

@staticmethod
def validate_task_to_prompt_name(
task_to_prompt: dict[str, str] | None,
raise_for_invalid_keys: bool = True,
) -> (
dict[str, str] | tuple[dict[str, str], Sequence[str]] | tuple[None, None] | None
):
"""Validates that the keys in task_to_prompt_name map to a known task or prompt type.

A key is valid if:

1. It is a valid task name; or
2. It is a valid task type; or
3. It is a valid prompt type; or
4. It is a compound key of the form "{task_name}-{prompt_type}" where task_name is a valid task type or task
name and prompt_type is a valid prompt type.

See the
[MTEB docs](https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts)
for a complete description of the order of precedence for these keys when running an evaluation.

Arguments:
task_to_prompt: The dictionary of prompts.
raise_for_invalid_keys: If True, raise a KeyError when an invalid key is encountered; otherwise return the
dictionary filtered to its valid keys along with the error messages for the invalid ones. Defaults to True
for backward compatibility.

Returns:
* None (or `(None, None)` when `raise_for_invalid_keys` is `False`) if `task_to_prompt` is None or empty;
* Only a dictionary of validated prompts if `raise_for_invalid_keys` is `True`; or
* A tuple containing the filtered dictionary of valid prompts and the error messages for the
invalid prompts if `raise_for_invalid_keys` is `False`.

Raises:
KeyError: If any invalid keys are encountered and `raise_for_invalid_keys` is `True`, this function will
raise a single `KeyError` containing the error messages for all invalid keys.
"""
if not task_to_prompt:
return None if raise_for_invalid_keys else (None, None)

task_types = get_args(TASK_TYPE)
prompt_types = [e.value for e in PromptType]
for task_name in task_to_prompt_name:
if "-" in task_name and task_name.endswith(
(f"-{PromptType.query.value}", f"-{PromptType.document.value}")
):
task_name, prompt_type = task_name.rsplit("-", 1)
if prompt_type not in prompt_types:
msg = f"Prompt type {prompt_type} is not valid. Valid prompt types are {prompt_types}"
logger.warning(msg)
raise KeyError(msg)
valid_keys_msg = f"Valid keys are task types [{task_types}], prompt types [{prompt_types}], and task names"
valid_prompt_type_endings = tuple(
[f"-{prompt_type}" for prompt_type in prompt_types]
)

invalid_keys: set[str] = set()
invalid_task_messages: set[str] = set()

for task_key in task_to_prompt:
# task_key may be a compound key of the form "{task_name}-{prompt_type}". A task_name may contain a "-"
# character (this occurs in ~12% of task names), so rsplit is used to separate a valid prompt_type postfix
# from the unvalidated task_name.
if task_key.endswith(valid_prompt_type_endings):
task_name = task_key.rsplit("-", 1)[0]
else:
task_name = task_key

if task_name not in task_types and task_name not in prompt_types:
task = mteb.get_task(task_name=task_name)
if not task:
msg = f"Task name {task_name} is not valid. Valid task names are task types [{task_types}], prompt types [{prompt_types}] and task names"
try:
mteb.get_task(task_name=task_name)
except KeyError:
msg = f"Task name {task_name} is not valid. {valid_keys_msg}"
logger.warning(msg)
raise KeyError(msg)
return task_to_prompt_name
invalid_task_messages.add(msg)
invalid_keys.add(task_key)

if raise_for_invalid_keys and invalid_task_messages:
raise KeyError(invalid_task_messages)
elif raise_for_invalid_keys:
return task_to_prompt
else:
return {
k: v for k, v in task_to_prompt.items() if k not in invalid_keys
}, tuple(invalid_task_messages)

@staticmethod
def get_instruction(
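The reworked `validate_task_to_prompt_name` keeps the legacy raising behaviour as the default and adds a filtering mode. A usage sketch follows, assuming `Wrapper` is importable from `mteb.models.wrapper` (the file shown above); the prompt strings and the invalid key `not_a_task` are made up for illustration.

from mteb.models.wrapper import Wrapper

prompts = {
    "query": "Represent the query for retrieval:",  # bare prompt type
    "Retrieval-document": "Represent the document:",  # compound "{task_type}-{prompt_type}" key
    "not_a_task": "Neither a task name, task type, nor prompt type.",
}

# Legacy behaviour (default): any invalid key raises a single KeyError.
try:
    Wrapper.validate_task_to_prompt_name(prompts)
except KeyError as err:
    print(err)

# New behaviour: keep the valid keys and get the error messages back instead of raising.
valid, errors = Wrapper.validate_task_to_prompt_name(prompts, raise_for_invalid_keys=False)
print(sorted(valid))  # ['Retrieval-document', 'query']
print(len(errors))    # 1

The parametrised tests added to tests/test_reproducible_workflow.py below exercise both of these modes.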
@@ -24,7 +24,13 @@
)
def test_benchmark_sentence_transformer(task: str | AbsTask, model_name: str):
"""Test that a task can be fetched and run"""
if isinstance(model_name, str):
model = SentenceTransformer(model_name)
model = SentenceTransformer(model_name)
# Prior to https://github.com/embeddings-benchmark/mteb/pull/3079 the
# SentenceTransformerWrapper would set the model's prompts to None because
# the mock tasks are not in the MTEB task registry. The linked PR changes
# this behavior and keeps the prompts as configured by the model, so this
# test now sets the prompts to None explicitly to preserve the legacy
# behavior and focus the test on the tasks instead of the prompts.
model.prompts = None
eval = MTEB(tasks=[task])
eval.run(model, output_folder="tests/results", overwrite_results=True)
65 changes: 56 additions & 9 deletions tests/test_reproducible_workflow.py
@@ -14,9 +14,16 @@
logging.basicConfig(level=logging.INFO)


@pytest.mark.parametrize("task_name", ["BornholmBitextMining"])
@pytest.mark.parametrize("model_name", ["sentence-transformers/all-MiniLM-L6-v2"])
@pytest.mark.parametrize("model_revision", ["8b3219a92973c328a8e22fadcfa821b5dc75636a"])
@pytest.mark.parametrize(
"task_name, model_name, model_revision",
[
(
"BornholmBitextMining",
"sentence-transformers/all-MiniLM-L6-v2",
"8b3219a92973c328a8e22fadcfa821b5dc75636a",
),
],
)
def test_reproducibility_workflow(task_name: str, model_name: str, model_revision: str):
"""Test that a model and a task can be fetched and run in a reproducible fashion."""
model_meta = mteb.get_model_meta(model_name, revision=model_revision)
@@ -67,11 +74,51 @@ def test_validate_task_to_prompt_name(task_name: str | mteb.AbsTask):
Wrapper.validate_task_to_prompt_name(model_prompts)


def test_validate_task_to_prompt_name_fail():
with pytest.raises(KeyError):
Wrapper.validate_task_to_prompt_name(
{"task_name": "prompt_name", "task_name-query": "prompt_name"}
)
@pytest.mark.parametrize("raise_for_invalid_keys", (True, False))
def test_validate_task_to_prompt_name_for_none(raise_for_invalid_keys: bool):
result = Wrapper.validate_task_to_prompt_name(
None, raise_for_invalid_keys=raise_for_invalid_keys
)
assert result == (None if raise_for_invalid_keys else (None, None))


@pytest.mark.parametrize(
"task_prompt_dict",
[
{"task_name": "prompt_name"},
{"task_name-query": "prompt_name"},
{"task_name-task_name": "prompt_name"},
],
)
def test_validate_task_to_prompt_name_fails_and_raises(
task_prompt_dict: dict[str, str],
):
with pytest.raises(KeyError):
Wrapper.validate_task_to_prompt_name({"task_name-task_name": "prompt_name"})
Wrapper.validate_task_to_prompt_name(task_prompt_dict)


@pytest.mark.parametrize(
"task_prompt_dict, expected_valid, expected_invalid",
[
({"task_name": "prompt_name"}, 0, 1),
({"task_name-query": "prompt_name"}, 0, 1),
(
{
"task_name-query": "prompt_name",
"query": "prompt_name",
"Retrieval": "prompt_name",
},
2,
1,
),
({"task_name-task_name": "prompt_name"}, 0, 1),
],
)
def test_validate_task_to_prompt_name_filters_and_reports(
task_prompt_dict: dict[str, str], expected_valid: int, expected_invalid: int
):
valid, invalid = Wrapper.validate_task_to_prompt_name(
task_prompt_dict, raise_for_invalid_keys=False
)
assert len(valid) == expected_valid
assert len(invalid) == expected_invalid