Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mteb/abstasks/AbsTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,13 @@ def filter_languages(
self.hf_subsets = subsets_to_keep
return self

@property
def is_aggregate(self) -> bool:
    """Report whether this task is an aggregate of other tasks.

    The base implementation always answers ``False``. Aggregate subclasses
    (e.g. ``AbsTaskAggregate``) are expected to override this property and
    return ``True`` instead.
    """
    return False

@property
def eval_splits(self) -> list[str]:
if self._eval_splits:
Expand Down
4 changes: 4 additions & 0 deletions mteb/abstasks/aggregated_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def _calculate_metrics_from_split(
"Aggregate tasks does not implement a _calculate_metrics_from_split. Instead use the individual tasks."
)

@property
def is_aggregate(self) -> bool:
    """Whether the task is aggregate; always ``True`` for this class.

    Overrides the ``is_aggregate`` property defined on ``AbsTask`` so that
    callers can detect aggregate tasks through the property rather than an
    ``isinstance`` check.
    """
    return True

@property
def eval_splits(self) -> list[str]:
if self._eval_splits:
Expand Down
3 changes: 1 addition & 2 deletions mteb/evaluation/MTEB.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from sentence_transformers import CrossEncoder, SentenceTransformer

from mteb.abstasks.AbsTask import ScoresDict
from mteb.abstasks.aggregated_task import AbsTaskAggregate
from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta
from mteb.models import model_meta_from_sentence_transformers
Expand Down Expand Up @@ -467,7 +466,7 @@ def run(
f"\n\n********************** Evaluating {task.metadata.name} **********************"
)

if isinstance(task, AbsTaskAggregate):
if task.is_aggregate:
self_ = MTEB(tasks=task.metadata.tasks)
task_results = self_.run(
model,
Expand Down
13 changes: 13 additions & 0 deletions mteb/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ def filter_tasks_by_modalities(
return [t for t in tasks if _modalities.intersection(t.modalities)]


def filter_aggregate_tasks(tasks: list[AbsTask]) -> list[AbsTask]:
    """Drop aggregate tasks, keeping only the non-aggregate ones.

    Args:
        tasks: A list of tasks to filter.

    Returns:
        A new list with every task whose ``is_aggregate`` is falsy, in the
        same order as the input. The input list is not modified.
    """
    non_aggregate = []
    for task in tasks:
        if task.is_aggregate:
            # Aggregate tasks are the ones being filtered out.
            continue
        non_aggregate.append(task)
    return non_aggregate


class MTEBTasks(tuple):
def __repr__(self) -> str:
return "MTEBTasks" + super().__repr__()
Expand Down Expand Up @@ -278,6 +287,7 @@ def get_tasks(
exclusive_language_filter: bool = False,
modalities: list[MODALITIES] | None = None,
exclusive_modality_filter: bool = False,
exclude_aggregate: bool = False,
) -> MTEBTasks:
"""Get a list of tasks based on the specified filters.

Expand All @@ -300,6 +310,7 @@ def get_tasks(
exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
task's modalities and ALL task modalities are in filter modalities (exact match).
If False, keep tasks if _any_ of the task's modalities match the filter modalities.
exclude_aggregate: If True, exclude aggregate tasks. If False, both aggregate and non-aggregate tasks are returned.

Returns:
A list of all initialized tasks objects which pass all of the filters (AND operation).
Expand Down Expand Up @@ -350,6 +361,8 @@ def get_tasks(
_tasks = filter_tasks_by_modalities(
_tasks, modalities, exclusive_modality_filter
)
if exclude_aggregate:
_tasks = filter_aggregate_tasks(_tasks)

return MTEBTasks(_tasks)

Expand Down
6 changes: 6 additions & 0 deletions tests/test_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_get_task(
@pytest.mark.parametrize("exclude_superseded_datasets", [True, False])
@pytest.mark.parametrize("modalities", [["text"], ["image"], ["text", "image"], None])
@pytest.mark.parametrize("exclusive_modality_filter", [True, False])
@pytest.mark.parametrize("exclude_aggregate", [True, False])
def test_get_tasks(
languages: list[str],
script: list[str],
Expand All @@ -80,6 +81,7 @@ def test_get_tasks(
exclude_superseded_datasets: bool,
modalities: list[MODALITIES] | None,
exclusive_modality_filter: bool,
exclude_aggregate: bool,
):
tasks = mteb.get_tasks(
languages=languages,
Expand All @@ -89,6 +91,7 @@ def test_get_tasks(
exclude_superseded=exclude_superseded_datasets,
modalities=modalities,
exclusive_modality_filter=exclusive_modality_filter,
exclude_aggregate=exclude_aggregate,
)

for task in tasks:
Expand All @@ -110,6 +113,9 @@ def test_get_tasks(
assert set(task.modalities) == set(modalities)
else:
assert any(mod in task.modalities for mod in modalities)
if exclude_aggregate:
# Aggregate tasks should be excluded
assert not task.is_aggregate


def test_get_tasks_filtering():
Expand Down
7 changes: 7 additions & 0 deletions tests/test_tasks/test_all_abstasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,10 @@ def test_superseded_dataset_exists():
assert task.superseded_by in TASKS_REGISTRY, (
f"{task} is superseded by {task.superseded_by} but {task.superseded_by} is not in the TASKS_REGISTRY"
)


def test_is_aggregate_property_correct():
    """Check that ``is_aggregate`` agrees with the task's class.

    For every task returned by ``mteb.get_tasks()``, the property must be
    true exactly when the task is an ``AbsTaskAggregate`` instance.
    """
    for candidate in mteb.get_tasks():
        reported = candidate.is_aggregate
        expected = isinstance(candidate, AbsTaskAggregate)
        assert reported == expected