diff --git a/mteb/abstasks/AbsTask.py b/mteb/abstasks/AbsTask.py index b81c860b83..93e990c76b 100644 --- a/mteb/abstasks/AbsTask.py +++ b/mteb/abstasks/AbsTask.py @@ -347,6 +347,13 @@ def filter_languages( self.hf_subsets = subsets_to_keep return self + @property + def is_aggregate( + self, + ) -> bool: # Overridden by subclasses (AbsTaskAggregate) that are aggregate + """Whether the task is aggregate. Subclasses that are aggregate should override this with `True`.""" + return False + @property def eval_splits(self) -> list[str]: if self._eval_splits: diff --git a/mteb/abstasks/aggregated_task.py b/mteb/abstasks/aggregated_task.py index 4c79db01ae..f43c19f237 100644 --- a/mteb/abstasks/aggregated_task.py +++ b/mteb/abstasks/aggregated_task.py @@ -154,6 +154,10 @@ def _calculate_metrics_from_split( "Aggregate tasks does not implement a _calculate_metrics_from_split. Instead use the individual tasks." ) + @property + def is_aggregate(self): # Overrides the is_aggregate property on AbsTask + return True + @property def eval_splits(self) -> list[str]: if self._eval_splits: diff --git a/mteb/evaluation/MTEB.py b/mteb/evaluation/MTEB.py index a6bacc189e..f1ae39ef9c 100644 --- a/mteb/evaluation/MTEB.py +++ b/mteb/evaluation/MTEB.py @@ -17,7 +17,6 @@ from sentence_transformers import CrossEncoder, SentenceTransformer from mteb.abstasks.AbsTask import ScoresDict -from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.encoder_interface import Encoder from mteb.model_meta import ModelMeta from mteb.models import model_meta_from_sentence_transformers @@ -467,7 +466,7 @@ def run( f"\n\n********************** Evaluating {task.metadata.name} **********************" ) - if isinstance(task, AbsTaskAggregate): + if task.is_aggregate: self_ = MTEB(tasks=task.metadata.tasks) task_results = self_.run( model, diff --git a/mteb/overview.py b/mteb/overview.py index 03ad0e67ba..461221e7fa 100644 --- a/mteb/overview.py +++ b/mteb/overview.py @@ -133,6 +133,15 @@ def 
filter_tasks_by_modalities( return [t for t in tasks if _modalities.intersection(t.modalities)] +def filter_aggregate_tasks(tasks: list[AbsTask]) -> list[AbsTask]: + """Returns input tasks that are *not* aggregate. + + Args: + tasks: A list of tasks to filter. + """ + return [t for t in tasks if not t.is_aggregate] + + class MTEBTasks(tuple): def __repr__(self) -> str: return "MTEBTasks" + super().__repr__() @@ -278,6 +287,7 @@ def get_tasks( exclusive_language_filter: bool = False, modalities: list[MODALITIES] | None = None, exclusive_modality_filter: bool = False, + exclude_aggregate: bool = False, ) -> MTEBTasks: """Get a list of tasks based on the specified filters. @@ -300,6 +310,7 @@ def get_tasks( exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the task's modalities and ALL task modalities are in filter modalities (exact match). If False, keep tasks if _any_ of the task's modalities match the filter modalities. + exclude_aggregate: If True, exclude aggregate tasks. If False, both aggregate and non-aggregate tasks are returned. Returns: A list of all initialized tasks objects which pass all of the filters (AND operation). 
@@ -350,6 +361,8 @@ def get_tasks( _tasks = filter_tasks_by_modalities( _tasks, modalities, exclusive_modality_filter ) + if exclude_aggregate: + _tasks = filter_aggregate_tasks(_tasks) return MTEBTasks(_tasks) diff --git a/tests/test_overview.py b/tests/test_overview.py index 4486bc1136..801929817c 100644 --- a/tests/test_overview.py +++ b/tests/test_overview.py @@ -72,6 +72,7 @@ def test_get_task( @pytest.mark.parametrize("exclude_superseded_datasets", [True, False]) @pytest.mark.parametrize("modalities", [["text"], ["image"], ["text", "image"], None]) @pytest.mark.parametrize("exclusive_modality_filter", [True, False]) +@pytest.mark.parametrize("exclude_aggregate", [True, False]) def test_get_tasks( languages: list[str], script: list[str], @@ -80,6 +81,7 @@ def test_get_tasks( exclude_superseded_datasets: bool, modalities: list[MODALITIES] | None, exclusive_modality_filter: bool, + exclude_aggregate: bool, ): tasks = mteb.get_tasks( languages=languages, @@ -89,6 +91,7 @@ def test_get_tasks( exclude_superseded=exclude_superseded_datasets, modalities=modalities, exclusive_modality_filter=exclusive_modality_filter, + exclude_aggregate=exclude_aggregate, ) for task in tasks: @@ -110,6 +113,9 @@ def test_get_tasks( assert set(task.modalities) == set(modalities) else: assert any(mod in task.modalities for mod in modalities) + if exclude_aggregate: + # Aggregate tasks should be excluded + assert not task.is_aggregate def test_get_tasks_filtering(): diff --git a/tests/test_tasks/test_all_abstasks.py b/tests/test_tasks/test_all_abstasks.py index b192b3b334..55e33e994a 100644 --- a/tests/test_tasks/test_all_abstasks.py +++ b/tests/test_tasks/test_all_abstasks.py @@ -106,3 +106,10 @@ def test_superseded_dataset_exists(): assert task.superseded_by in TASKS_REGISTRY, ( f"{task} is superseded by {task.superseded_by} but {task.superseded_by} is not in the TASKS_REGISTRY" ) + + +def test_is_aggregate_property_correct(): + tasks = mteb.get_tasks() + + for task in tasks: + assert 
task.is_aggregate == isinstance(task, AbsTaskAggregate)  # NOTE(review): this hunk adds no import of AbsTaskAggregate — confirm the test module already imports it