diff --git a/docs/api/task.md b/docs/api/task.md
index 83a6ccc32b..81ed4481a8 100644
--- a/docs/api/task.md
+++ b/docs/api/task.md
@@ -14,6 +14,8 @@ A task is an implementation of a dataset for evaluation. It could, for instance,
 
 :::mteb.get_task
 
+:::mteb.filter_tasks
+
 ## Metadata
 
 Each task also contains extensive metadata. We annotate this using the following object, which allows us to use [pydantic](https://docs.pydantic.dev/latest/) to validate the metadata.
diff --git a/docs/overview/create_available_benchmarks.py b/docs/overview/create_available_benchmarks.py
index b25716fefc..4013777786 100644
--- a/docs/overview/create_available_benchmarks.py
+++ b/docs/overview/create_available_benchmarks.py
@@ -4,7 +4,7 @@
 from typing import cast
 
 import mteb
-from mteb.overview import MTEBTasks
+from mteb.get_tasks import MTEBTasks
 
 START_INSERT = ""
 END_INSERT = ""
diff --git a/mteb/MTEB.py b/mteb/MTEB.py
index 34708976b6..bdab22a731 100644
--- a/mteb/MTEB.py
+++ b/mteb/MTEB.py
@@ -122,7 +122,7 @@ def _display_tasks(self, task_list: Iterable[AbsTask], name: str | None = None):
     def mteb_benchmarks(self):
         """Get all benchmarks available in the MTEB."""
-        from mteb.overview import MTEBTasks
+        from mteb.get_tasks import MTEBTasks
 
         # get all the MTEB specific benchmarks:
         sorted_mteb_benchmarks = sorted(
diff --git a/mteb/__init__.py b/mteb/__init__.py
index 35cdfbfe2e..d65f8a991d 100644
--- a/mteb/__init__.py
+++ b/mteb/__init__.py
@@ -3,11 +3,12 @@
 from mteb.abstasks import AbsTask
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.evaluate import evaluate
+from mteb.filter_tasks import filter_tasks
+from mteb.get_tasks import get_task, get_tasks
 from mteb.load_results import load_results
 from mteb.models import Encoder, SentenceTransformerEncoderWrapper
 from mteb.models.get_model_meta import get_model, get_model_meta, get_model_metas
 from mteb.MTEB import MTEB
-from mteb.overview import get_task, get_tasks
 from mteb.results import BenchmarkResults, TaskResult
 
 from .benchmarks.benchmark import Benchmark
@@ -25,6 +26,7 @@
     "TaskMetadata",
     "TaskResult",
     "evaluate",
+    "filter_tasks",
     "get_benchmark",
     "get_benchmarks",
     "get_model",
diff --git a/mteb/abstasks/abstask.py b/mteb/abstasks/abstask.py
index d994634610..b120fdc77d 100644
--- a/mteb/abstasks/abstask.py
+++ b/mteb/abstasks/abstask.py
@@ -432,32 +432,6 @@ def filter_eval_splits(self, eval_splits: list[str] | None) -> Self:
         self._eval_splits = eval_splits
         return self
 
-    def filter_modalities(
-        self, modalities: list[str] | None, exclusive_modality_filter: bool = False
-    ) -> Self:
-        """Filter the modalities of the task.
-
-        Args:
-            modalities: A list of modalities to filter by. If None, the task is returned unchanged.
-            exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
-                task's modalities and ALL task modalities are in filter modalities (exact match).
-                If False, keep tasks if _any_ of the task's modalities match the filter modalities.
-
-        Returns:
-            The filtered task
-        """
-        if modalities is None:
-            return self
-        filter_modalities_set = set(modalities)
-        task_modalities_set = set(self.modalities)
-        if exclusive_modality_filter:
-            if not (filter_modalities_set == task_modalities_set):
-                self.metadata.modalities = []
-        else:
-            if not filter_modalities_set.intersection(task_modalities_set):
-                self.metadata.modalities = []
-        return self
-
     def filter_languages(
         self,
         languages: list[str] | None,
diff --git a/mteb/benchmarks/_create_table.py b/mteb/benchmarks/_create_table.py
index 6efdf80c75..5e63006e89 100644
--- a/mteb/benchmarks/_create_table.py
+++ b/mteb/benchmarks/_create_table.py
@@ -6,7 +6,7 @@
 import pandas as pd
 
 import mteb
-from mteb.overview import get_task, get_tasks
+from mteb.get_tasks import get_task, get_tasks
 from mteb.results.benchmark_results import BenchmarkResults
diff --git a/mteb/benchmarks/benchmarks/benchmarks.py b/mteb/benchmarks/benchmarks/benchmarks.py
index 0ebf56a6c3..e29eab05b3 100644
--- a/mteb/benchmarks/benchmarks/benchmarks.py
+++ b/mteb/benchmarks/benchmarks/benchmarks.py
@@ -1,5 +1,5 @@
 from mteb.benchmarks.benchmark import Benchmark, HUMEBenchmark, MIEBBenchmark
-from mteb.overview import MTEBTasks, get_task, get_tasks
+from mteb.get_tasks import MTEBTasks, get_task, get_tasks
 
 MMTEB_CITATION = r"""@article{enevoldsen2025mmtebmassivemultilingualtext,
   author = {Kenneth Enevoldsen and Isaac Chung and Imene Kerboua and Márton Kardos and Ashwin Mathur and David Stap and Jay Gala and Wissam Siblini and Dominik Krzemiński and Genta Indra Winata and Saba Sturua and Saiteja Utpala and Mathieu Ciancone and Marion Schaeffer and Gabriel Sequeira and Diganta Misra and Shreeya Dhakal and Jonathan Rystrøm and Roman Solomatin and Ömer Çağatan and Akash Kundu and Martin Bernstorff and Shitao Xiao and Akshita Sukhlecha and Bhavish Pahwa and Rafał Poświata and Kranthi Kiran GV and Shawon Ashraf and Daniel Auras and Björn Plüster and Jan Philipp Harries and Loïc Magne and Isabelle Mohr and Mariya Hendriksen and Dawei Zhu and Hippolyte Gisserot-Boukhlef and Tom Aarsen and Jan Kostkan and Konrad Wojtasik and Taemin Lee and Marek Šuppa and Crystina Zhang and Roberta Rocca and Mohammed Hamdy and Andrianos Michail and John Yang and Manuel Faysse and Aleksei Vatolin and Nandan Thakur and Manan Dey and Dipam Vasani and Pranjal Chitale and Simone Tedeschi and Nguyen Tai and Artem Snegirev and Michael Günther and Mengzhou Xia and Weijia Shi and Xing Han Lù and Jordan Clive and Gayatri Krishnakumar and Anna Maksimova and Silvan Wehrli and Maria Tikhonova and Henil Panchal and Aleksandr Abramov and Malte Ostendorff and Zheng Liu and Simon Clematide and Lester James Miranda and Alena Fenogenova and Guangyu Song and Ruqiya Bin Safi and Wen-Ding Li and Alessia Borghini and Federico Cassano and Hongjin Su and Jimmy Lin and Howard Yen and Lasse Hansen and Sara Hooker and Chenghao Xiao and Vaibhav Adlakha and Orion Weller and Siva Reddy and Niklas Muennighoff},
diff --git a/mteb/benchmarks/benchmarks/rteb_benchmarks.py b/mteb/benchmarks/benchmarks/rteb_benchmarks.py
index ae983a96e6..a13c5088c0 100644
--- a/mteb/benchmarks/benchmarks/rteb_benchmarks.py
+++ b/mteb/benchmarks/benchmarks/rteb_benchmarks.py
@@ -2,7 +2,7 @@
 
 from mteb.benchmarks.benchmark import RtebBenchmark
 
-from mteb.overview import get_tasks
+from mteb.get_tasks import get_tasks
 
 RTEB_CITATION = r"""@article{rteb2025,
   author = {Liu, Frank and Enevoldsen, Kenneth and Solomatin, Roman and Chung, Isaac and Aarsen, Tom and Fődi, Zoltán},
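A note on semantics for this hunk: the removed `AbsTask.filter_modalities` mutated the task in place, blanking `metadata.modalities` on a failed match, whereas the module-level `mteb.filter_tasks` introduced below selects from a collection without mutating its elements. A minimal sketch of the replacement pattern, with illustrative filter values:

import mteb

# Before (removed in this PR): task.filter_modalities(["text"])
# After: filter at the collection level instead of mutating a task.
tasks = mteb.get_tasks(languages=["eng"])
text_only = mteb.filter_tasks(tasks, modalities=["text"])

# Exact-match filtering, previously the method's exclusive_modality_filter flag:
text_image = mteb.filter_tasks(
    tasks, modalities=["text", "image"], exclusive_modality_filter=True
)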
diff --git a/mteb/cli/_display_tasks.py b/mteb/cli/_display_tasks.py
index 5129545f46..4b4fa1268d 100644
--- a/mteb/cli/_display_tasks.py
+++ b/mteb/cli/_display_tasks.py
@@ -2,7 +2,7 @@
 
 from mteb.abstasks import AbsTask
 from mteb.benchmarks import Benchmark
-from mteb.overview import MTEBTasks
+from mteb.get_tasks import MTEBTasks
 
 
 def _display_benchmarks(benchmarks: Sequence[Benchmark]) -> None:
diff --git a/mteb/filter_tasks.py b/mteb/filter_tasks.py
new file mode 100644
index 0000000000..b95a755ab3
--- /dev/null
+++ b/mteb/filter_tasks.py
@@ -0,0 +1,172 @@
+"""Functions for filtering a collection of MTEB tasks by language, script, domain, task type, category, and modality."""
+
+import logging
+from collections.abc import Sequence
+from typing import overload
+
+from mteb.abstasks import (
+    AbsTask,
+)
+from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType
+from mteb.languages import (
+    ISO_TO_LANGUAGE,
+    ISO_TO_SCRIPT,
+)
+from mteb.types import Modalities
+
+logger = logging.getLogger(__name__)
+
+
+def _check_is_valid_script(script: str) -> None:
+    if script not in ISO_TO_SCRIPT:
+        raise ValueError(
+            f"Invalid script code: '{script}', you can see valid ISO 15924 codes using `from mteb.languages import ISO_TO_SCRIPT`."
+        )
+
+
+def _check_is_valid_language(lang: str) -> None:
+    if lang not in ISO_TO_LANGUAGE:
+        raise ValueError(
+            f"Invalid language code: '{lang}', you can see valid ISO 639-3 codes using `from mteb.languages import ISO_TO_LANGUAGE`."
+        )
+
+
+@overload
+def filter_tasks(
+    tasks: Sequence[AbsTask],
+    *,
+    languages: list[str] | None = None,
+    script: list[str] | None = None,
+    domains: list[TaskDomain] | None = None,
+    task_types: list[TaskType] | None = None,  # type: ignore
+    categories: list[TaskCategory] | None = None,
+    modalities: list[Modalities] | None = None,
+    exclusive_modality_filter: bool = False,
+    exclude_superseded: bool = False,
+    exclude_aggregate: bool = False,
+    exclude_private: bool = False,
+) -> list[AbsTask]: ...
+
+
+@overload
+def filter_tasks(
+    tasks: Sequence[type[AbsTask]],
+    *,
+    languages: list[str] | None = None,
+    script: list[str] | None = None,
+    domains: list[TaskDomain] | None = None,
+    task_types: list[TaskType] | None = None,  # type: ignore
+    categories: list[TaskCategory] | None = None,
+    modalities: list[Modalities] | None = None,
+    exclusive_modality_filter: bool = False,
+    exclude_superseded: bool = False,
+    exclude_aggregate: bool = False,
+    exclude_private: bool = False,
+) -> list[type[AbsTask]]: ...
+
+
+def filter_tasks(
+    tasks: Sequence[AbsTask] | Sequence[type[AbsTask]],
+    *,
+    languages: list[str] | None = None,
+    script: list[str] | None = None,
+    domains: list[TaskDomain] | None = None,
+    task_types: list[TaskType] | None = None,  # type: ignore
+    categories: list[TaskCategory] | None = None,
+    modalities: list[Modalities] | None = None,
+    exclusive_modality_filter: bool = False,
+    exclude_superseded: bool = False,
+    exclude_aggregate: bool = False,
+    exclude_private: bool = False,
+) -> list[AbsTask] | list[type[AbsTask]]:
+    """Filter tasks based on the specified criteria.
+
+    Args:
+        tasks: A sequence of task objects or task classes to filter.
+        languages: A list of languages either specified as 3 letter languages codes (ISO 639-3, e.g. "eng") or as script languages codes e.g.
+            "eng-Latn". If None, all languages are included.
+        script: A list of script codes (ISO 15924 codes, e.g. "Latn"). If None, all scripts are included.
+        domains: A list of task domains, e.g. "Legal", "Medical", "Fiction". If None, all domains are included.
+        task_types: A list of task types, e.g. "Classification" or "Retrieval". If None, all task types are included.
+        categories: A list of task categories, these include "t2t" (text to text) and "t2i" (text to image). See TaskMetadata for the full list.
+        modalities: A list of modalities to include. If None, all modalities are included.
+        exclusive_modality_filter: If True, only keep tasks whose modalities exactly match the filter modalities, i.e. _all_ filter
+            modalities are in the task's modalities and _all_ task modalities are in the filter modalities.
+            If False, keep tasks if _any_ of the task's modalities match the filter modalities.
+        exclude_superseded: If True, exclude datasets which are superseded by another.
+        exclude_aggregate: If True, exclude aggregate tasks. If False, both aggregate and non-aggregate tasks are returned.
+        exclude_private: If True, exclude private/closed datasets (is_public=False). If False (default), include both public and private datasets.
+
+    Returns:
+        A list of task objects (or task classes, matching the input) which pass all of the filters.
+
+    Examples:
+        >>> text_classification_tasks = filter_tasks(my_tasks, task_types=["Classification"], modalities=["text"])
+        >>> medical_tasks = filter_tasks(my_tasks, domains=["Medical"])
+        >>> english_tasks = filter_tasks(my_tasks, languages=["eng"])
+        >>> latin_script_tasks = filter_tasks(my_tasks, script=["Latn"])
+        >>> text_image_tasks = filter_tasks(my_tasks, modalities=["text", "image"], exclusive_modality_filter=True)
+
+    """
+    langs_to_keep = None
+    if languages:
+        for lang in languages:
+            _check_is_valid_language(lang)
+        langs_to_keep = set(languages)
+
+    script_to_keep = None
+    if script:
+        for s in script:
+            _check_is_valid_script(s)
+        script_to_keep = set(script)
+
+    domains_to_keep = None
+    if domains:
+        domains_to_keep = set(domains)
+
+    def _convert_to_set(domain: list[TaskDomain] | None) -> set:
+        return set(domain) if domain is not None else set()
+
+    task_types_to_keep = None
+    if task_types:
+        task_types_to_keep = set(task_types)
+
+    categories_to_keep = None
+    if categories:
+        categories_to_keep = set(categories)
+
+    modalities_to_keep = None
+    if modalities:
+        modalities_to_keep = set(modalities)
+
+    _tasks = []
+    for t in tasks:
+        if langs_to_keep and not langs_to_keep.intersection(t.metadata.languages):
+            continue
+        if script_to_keep and not script_to_keep.intersection(t.metadata.scripts):
+            continue
+        if domains_to_keep and not domains_to_keep.intersection(
+            _convert_to_set(t.metadata.domains)
+        ):
+            continue
+        if task_types_to_keep and t.metadata.type not in task_types_to_keep:
+            continue
+        if categories_to_keep and t.metadata.category not in categories_to_keep:
+            continue
+        if modalities_to_keep:
+            if exclusive_modality_filter:
+                if set(t.metadata.modalities) != modalities_to_keep:
+                    continue
+            else:
+                if not modalities_to_keep.intersection(t.metadata.modalities):
+                    continue
+        if exclude_superseded and t.superseded_by is not None:
+            continue
+        if exclude_aggregate and t.is_aggregate:
+            continue
+        if exclude_private and not t.metadata.is_public:
+            continue
+
+        _tasks.append(t)
+
+    return _tasks
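The two `@overload` signatures make the function type-preserving: a sequence of instantiated tasks yields instances, while a sequence of task classes yields classes, which is how `get_tasks` uses it below to filter on class-level metadata before instantiation. A rough usage sketch, assuming tasks come from `mteb.get_tasks` (the filter values are illustrative):

import mteb
from mteb.filter_tasks import filter_tasks

all_tasks = mteb.get_tasks()  # instantiated tasks, so the result is instances too

# Filters within one call combine conjunctively:
legal_retrieval = filter_tasks(
    all_tasks,
    domains=["Legal"],
    task_types=["Retrieval"],
    exclude_superseded=True,
)

# And since the result is a plain list, calls also compose by chaining:
public_text = filter_tasks(
    filter_tasks(all_tasks, modalities=["text"]),
    exclude_private=True,
)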
diff --git a/mteb/overview.py b/mteb/get_tasks.py
similarity index 72%
rename from mteb/overview.py
rename to mteb/get_tasks.py
index 9c16e82dee..65ed716de0 100644
--- a/mteb/overview.py
+++ b/mteb/get_tasks.py
@@ -12,17 +12,14 @@
     AbsTask,
 )
 from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType
-from mteb.languages import (
-    ISO_TO_LANGUAGE,
-    ISO_TO_SCRIPT,
-)
+from mteb.filter_tasks import filter_tasks
 from mteb.types import Modalities
 
 logger = logging.getLogger(__name__)
 
 
 # Create task registry
-def _create_task_list() -> list[type[AbsTask]]:
+def _gather_tasks() -> tuple[type[AbsTask], ...]:
     import mteb.tasks as tasks
 
     tasks = [
@@ -30,11 +27,12 @@ def _create_task_list() -> list[type[AbsTask]]:
         for t in tasks.__dict__.values()
         if isinstance(t, type) and issubclass(t, AbsTask)
     ]
-    return tasks
+    return tuple(tasks)
 
 
-def _create_name_to_task_mapping() -> dict[str, type[AbsTask]]:
-    tasks = _create_task_list()
+def _create_name_to_task_mapping(
+    tasks: Sequence[type[AbsTask]],
+) -> dict[str, type[AbsTask]]:
     metadata_names = {}
     for cls in tasks:
         if cls.metadata.name in metadata_names:
@@ -45,13 +43,12 @@ def _create_name_to_task_mapping() -> dict[str, type[AbsTask]]:
     return metadata_names
 
 
-def _create_similar_tasks() -> dict[str, list[str]]:
+def _create_similar_tasks(tasks: Sequence[type[AbsTask]]) -> dict[str, list[str]]:
     """Create a dictionary of similar tasks.
 
     Returns:
         Dict with key is parent task and value is list of similar tasks.
     """
-    tasks = _create_task_list()
     similar_tasks = defaultdict(list)
     for task in tasks:
         if task.metadata.adapted_from:
@@ -60,90 +57,9 @@ def _create_similar_tasks(tasks: Sequence[type[AbsTask]]) -> dict[str, list[str]]:
     return similar_tasks
 
 
-_TASKS_REGISTRY = _create_name_to_task_mapping()
-_SIMILAR_TASKS = _create_similar_tasks()
-
-
-def _check_is_valid_script(script: str) -> None:
-    if script not in ISO_TO_SCRIPT:
-        raise ValueError(
-            f"Invalid script code: '{script}', you can see valid ISO 15924 codes using `from mteb.languages import ISO_TO_SCRIPT`."
-        )
-
-
-def _check_is_valid_language(lang: str) -> None:
-    if lang not in ISO_TO_LANGUAGE:
-        raise ValueError(
-            f"Invalid language code: '{lang}', you can see valid ISO 639-3 codes using `from mteb.languages import ISO_TO_LANGUAGE`."
-        )
-
-
-def _filter_superseded_datasets(tasks: list[AbsTask]) -> list[AbsTask]:
-    return [t for t in tasks if t.superseded_by is None]
-
-
-def _filter_tasks_by_languages(
-    tasks: list[AbsTask], languages: list[str]
-) -> list[AbsTask]:
-    [_check_is_valid_language(lang) for lang in languages]
-    langs_to_keep = set(languages)
-    return [t for t in tasks if langs_to_keep.intersection(t.metadata.languages)]
-
-
-def _filter_tasks_by_script(tasks: list[AbsTask], script: list[str]) -> list[AbsTask]:
-    [_check_is_valid_script(s) for s in script]
-    script_to_keep = set(script)
-    return [t for t in tasks if script_to_keep.intersection(t.metadata.scripts)]
-
-
-def _filter_tasks_by_domains(
-    tasks: list[AbsTask], domains: list[TaskDomain]
-) -> list[AbsTask]:
-    domains_to_keep = set(domains)
-
-    def _convert_to_set(domain: list[TaskDomain] | None) -> set:
-        return set(domain) if domain is not None else set()
-
-    return [
-        t
-        for t in tasks
-        if domains_to_keep.intersection(_convert_to_set(t.metadata.domains))
-    ]
-
-
-def _filter_tasks_by_task_types(
-    tasks: list[AbsTask], task_types: list[TaskType]
-) -> list[AbsTask]:
-    _task_types = set(task_types)
-    return [t for t in tasks if t.metadata.type in _task_types]
-
-
-def _filter_task_by_categories(
-    tasks: list[AbsTask], categories: list[TaskCategory]
-) -> list[AbsTask]:
-    _categories = set(categories)
-    return [t for t in tasks if t.metadata.category in _categories]
-
-
-def _filter_tasks_by_modalities(
-    tasks: list[AbsTask],
-    modalities: list[Modalities],
-    exclude_modality_filter: bool = False,
-) -> list[AbsTask]:
-    _modalities = set(modalities)
-    if exclude_modality_filter:
-        return [t for t in tasks if set(t.modalities) == _modalities]
-    else:
-        return [t for t in tasks if _modalities.intersection(t.modalities)]
-
-
-def _filter_aggregate_tasks(tasks: list[AbsTask]) -> list[AbsTask]:
-    """Returns input tasks that are *not* aggregate.
-
-    Args:
-        tasks: A list of tasks to filter.
-    """
-    return [t for t in tasks if not t.is_aggregate]
+TASK_LIST = _gather_tasks()
+_TASKS_REGISTRY = _create_name_to_task_mapping(TASK_LIST)
+_SIMILAR_TASKS = _create_similar_tasks(TASK_LIST)
 
 
 _DEFAULT_PROPRIETIES = (
@@ -299,7 +215,7 @@ def get_tasks(
     languages: list[str] | None = None,
     script: list[str] | None = None,
     domains: list[TaskDomain] | None = None,
-    task_types: list[TaskType] | None = None,
+    task_types: list[TaskType] | None = None,  # type: ignore
     categories: list[TaskCategory] | None = None,
     exclude_superseded: bool = True,
     eval_splits: list[str] | None = None,
@@ -312,7 +228,7 @@ def get_tasks(
     """Get a list of tasks based on the specified filters.
 
     Args:
-        tasks: A list of task names to include. If None, all tasks which pass the filters are included.
+        tasks: A list of task names to include. If None, all tasks which pass the filters are included. If passed, only the languages, script, and eval_splits filters are applied; the remaining filters are ignored.
         languages: A list of languages either specified as 3 letter languages codes (ISO 639-3, e.g. "eng") or as script languages codes e.g.
             "eng-Latn". For multilingual tasks this will also remove languages that are not in the specified list.
         script: A list of script codes (ISO 15924 codes, e.g. "Latn"). If None, all scripts are included. For multilingual tasks this will also remove scripts
@@ -343,6 +259,10 @@ def get_tasks(
         >>> get_tasks(tasks=["STS22"], languages=["eng"], exclusive_language_filter=True) # don't include multilingual subsets containing English
     """
     if tasks:
+        if domains or task_types or categories or modalities:
+            logger.warning(
+                "When `tasks` is provided, the filters domains, task_types, categories, and modalities are ignored."
+            )
         _tasks = [
             get_task(
                 task,
@@ -350,44 +270,29 @@ def get_tasks(
                 script,
                 eval_splits=eval_splits,
                 exclusive_language_filter=exclusive_language_filter,
-                modalities=modalities,
-                exclusive_modality_filter=exclusive_modality_filter,
             )
             for task in tasks
         ]
         return MTEBTasks(_tasks)
 
+    _tasks = filter_tasks(
+        TASK_LIST,
+        languages=languages,
+        script=script,
+        domains=domains,
+        task_types=task_types,
+        categories=categories,
+        modalities=modalities,
+        exclusive_modality_filter=exclusive_modality_filter,
+        exclude_superseded=exclude_superseded,
+        exclude_aggregate=exclude_aggregate,
+        exclude_private=exclude_private,
+    )
     _tasks = [
         cls().filter_languages(languages, script).filter_eval_splits(eval_splits)
-        for cls in _create_task_list()
+        for cls in _tasks
     ]
-    if languages:
-        _tasks = _filter_tasks_by_languages(_tasks, languages)
-    if script:
-        _tasks = _filter_tasks_by_script(_tasks, script)
-    if domains:
-        _tasks = _filter_tasks_by_domains(_tasks, domains)
-    if task_types:
-        _tasks = _filter_tasks_by_task_types(_tasks, task_types)
-    if categories:
-        logger.warning(
-            "`s2p`, `p2p`, and `s2s` will be removed and replaced by `t2t` in v2.0.0."
-        )
-        _tasks = _filter_task_by_categories(_tasks, categories)
-    if exclude_superseded:
-        _tasks = _filter_superseded_datasets(_tasks)
-    if modalities:
-        _tasks = _filter_tasks_by_modalities(
-            _tasks, modalities, exclusive_modality_filter
-        )
-    if exclude_aggregate:
-        _tasks = _filter_aggregate_tasks(_tasks)
-
-    # Apply privacy filtering
-    if exclude_private:
-        _tasks = [t for t in _tasks if t.metadata.is_public]
-
     return MTEBTasks(_tasks)
 
 
@@ -401,8 +306,6 @@ def get_task(
     eval_splits: list[str] | None = None,
     hf_subsets: list[str] | None = None,
     exclusive_language_filter: bool = False,
-    modalities: list[Modalities] | None = None,
-    exclusive_modality_filter: bool = False,
 ) -> AbsTask:
     """Get a task by name.
 
@@ -416,10 +319,6 @@ def get_task(
         exclusive_language_filter: Some datasets contains more than one language e.g. for STS22 the subset "de-en" contain eng and deu. If
             exclusive_language_filter is set to False both of these will be kept, but if set to True only those that contains all the languages
             specified will be kept.
-        modalities: A list of modalities to include. If None, all modalities are included.
-        exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
-            task's modalities and ALL task modalities are in filter modalities (exact match).
-            If False, keep tasks if _any_ of the task's modalities match the filter modalities.
 
     Returns:
         An initialized task object.
@@ -445,8 +344,6 @@ def get_task(
     task = _TASKS_REGISTRY[task_name]()
     if eval_splits:
         task.filter_eval_splits(eval_splits=eval_splits)
-    if modalities:
-        task.filter_modalities(modalities, exclusive_modality_filter)
     return task.filter_languages(
         languages,
         script,
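One behavioral consequence of this hunk: `get_task` no longer accepts `modalities` or `exclusive_modality_filter`, so modality filtering is done over collections instead. A before/after sketch (the task name is borrowed from the tests below and is only illustrative):

import mteb

# Before: mteb.get_task("BornholmBitextMining", modalities=["text"])
# After: resolve the task by name, and filter by modality at the collection level.
task = mteb.get_task("BornholmBitextMining", eval_splits=["test"])
text_tasks = mteb.filter_tasks(mteb.get_tasks(), modalities=["text"])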
""" - from mteb.overview import get_task + from mteb.get_tasks import get_task if task is None: task = get_task(self.task_name) diff --git a/tests/test_abstasks/test_private_tasks.py b/tests/test_abstasks/test_private_tasks.py index 8b0b0d30bc..76ffe249bb 100644 --- a/tests/test_abstasks/test_private_tasks.py +++ b/tests/test_abstasks/test_private_tasks.py @@ -1,4 +1,4 @@ -from mteb.overview import get_tasks +from mteb.get_tasks import get_tasks # List of accepted private tasks - update this list as needed ACCEPTED_PRIVATE_TASKS = [ diff --git a/tests/test_benchmarks/test_get_benchmarks.py b/tests/test_benchmarks/test_get_benchmarks.py index 640d894875..3a898fba9f 100644 --- a/tests/test_benchmarks/test_get_benchmarks.py +++ b/tests/test_benchmarks/test_get_benchmarks.py @@ -5,7 +5,7 @@ import pytest import mteb -import mteb.overview +import mteb.get_tasks logging.basicConfig(level=logging.INFO) diff --git a/tests/test_benchmarks/test_names_must_be unique.py b/tests/test_benchmarks/test_names_must_be unique.py index a4f13b51af..7f51ed9db6 100644 --- a/tests/test_benchmarks/test_names_must_be unique.py +++ b/tests/test_benchmarks/test_names_must_be unique.py @@ -3,7 +3,7 @@ import logging import mteb -import mteb.overview +import mteb.get_tasks logging.basicConfig(level=logging.INFO) diff --git a/tests/test_deprecated/test_MTEB.py b/tests/test_deprecated/test_MTEB.py index 05efbc8264..8977ce3d63 100644 --- a/tests/test_deprecated/test_MTEB.py +++ b/tests/test_deprecated/test_MTEB.py @@ -8,7 +8,6 @@ import pytest import mteb -import mteb.overview logging.basicConfig(level=logging.INFO) diff --git a/tests/test_filter_tasks.py b/tests/test_filter_tasks.py new file mode 100644 index 0000000000..ba077b8094 --- /dev/null +++ b/tests/test_filter_tasks.py @@ -0,0 +1,143 @@ +import pytest + +from mteb import get_tasks +from mteb.abstasks.abstask import AbsTask +from mteb.abstasks.task_metadata import TaskDomain, TaskType +from mteb.filter_tasks import filter_tasks +from mteb.types import Modalities + + +@pytest.fixture +def all_tasks(): + return get_tasks() + + +def test_get_tasks_size_differences(all_tasks: list[AbsTask]): + assert len(all_tasks) > 0 + assert len(all_tasks) >= len(filter_tasks(all_tasks, script=["Latn"])) + assert len(all_tasks) >= len(filter_tasks(all_tasks, domains=["Legal"])) + assert len(all_tasks) >= len(filter_tasks(all_tasks, languages=["eng", "deu"])) + text_task = filter_tasks(all_tasks, modalities=["text"]) + assert len(all_tasks) >= len(text_task) + assert len(filter_tasks(all_tasks, modalities=["text", "image"])) >= len(text_task) + + +@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"], None]) +@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None]) +@pytest.mark.parametrize("domains", [["Legal"], ["Medical", "Non-fiction"], None]) +@pytest.mark.parametrize("task_types", [["Classification"], None]) +def test_filter_tasks( + all_tasks: list[AbsTask], + languages: list[str], + script: list[str], + domains: list[TaskDomain], + task_types: list[TaskType] | None, # type: ignore +): + """Tests that get_tasks filters tasks correctly. This could in principle be combined with the following tests, but they have been kept + seperate to reduce the grid size. 
+ """ + tasks = filter_tasks( + all_tasks, + languages=languages, + script=script, + domains=domains, + task_types=task_types, + ) + + for task in tasks: + if languages: + assert set(languages).intersection(task.metadata.languages) + if script: + assert set(script).intersection(task.metadata.scripts) + if domains: + task_domains = ( + set(task.metadata.domains) if task.metadata.domains else set() + ) + assert set(domains).intersection(set(task_domains)) + if task_types: + assert task.metadata.type in task_types + + +@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"]]) +@pytest.mark.parametrize("domains", [["Medical", "Non-fiction"], None]) +@pytest.mark.parametrize("task_types", [["Classification"], None]) +@pytest.mark.parametrize("exclude_superseded_datasets", [True, False]) +def test_filter_tasks_superseded( + all_tasks: list[AbsTask], + languages: list[str], + domains: list[TaskDomain], + task_types: list[TaskType] | None, # type: ignore + exclude_superseded_datasets: bool, +): + tasks = filter_tasks( + all_tasks, + languages=languages, + domains=domains, + task_types=task_types, + exclude_superseded=exclude_superseded_datasets, + ) + + for task in tasks: + if languages: + assert set(languages).intersection(task.metadata.languages) + if domains: + task_domains = ( + set(task.metadata.domains) if task.metadata.domains else set() + ) + assert set(domains).intersection(set(task_domains)) + if task_types: + assert task.metadata.type in task_types + if exclude_superseded_datasets: + assert task.superseded_by is None + + +@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"]]) +@pytest.mark.parametrize("modalities", [["text"], ["image"], ["text", "image"], None]) +@pytest.mark.parametrize("exclusive_modality_filter", [True, False]) +def test_filter_tasks_modalities( + all_tasks: list[AbsTask], + languages: list[str], + modalities: list[Modalities] | None, + exclusive_modality_filter: bool, +): + tasks = filter_tasks( + all_tasks, + languages=languages, + modalities=modalities, + exclusive_modality_filter=exclusive_modality_filter, + ) + + for task in tasks: + if languages: + assert set(languages).intersection(task.metadata.languages) + if modalities: + if exclusive_modality_filter: + assert set(task.modalities) == set(modalities) + else: + assert any(mod in task.modalities for mod in modalities) + + +@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"], None]) +@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None]) +@pytest.mark.parametrize("exclude_aggregate", [True, False]) +def test_filter_tasks_exclude_aggregate( + all_tasks: list[AbsTask], + languages: list[str], + script: list[str], + exclude_aggregate: bool, +): + tasks = filter_tasks( + all_tasks, + languages=languages, + script=script, + exclude_aggregate=exclude_aggregate, + ) + + for task in tasks: + if languages: + assert set(languages).intersection(task.metadata.languages) + if script: + assert set(script).intersection(task.metadata.scripts) + if exclude_aggregate: + # Aggregate tasks should be excluded + assert not task.is_aggregate diff --git a/tests/test_get_tasks.py b/tests/test_get_tasks.py index 6bc3aec451..cf7ed0ad17 100644 --- a/tests/test_get_tasks.py +++ b/tests/test_get_tasks.py @@ -3,39 +3,27 @@ import mteb from mteb import get_task, get_tasks from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.task_metadata import TaskDomain, TaskType -from mteb.overview import MTEBTasks +from mteb.abstasks.task_metadata import TaskType +from mteb.get_tasks import MTEBTasks from 
mteb.types import Modalities -def test_get_tasks_size_differences(): - default_tasks = get_tasks() - assert len(default_tasks) > 0 - assert len(default_tasks) >= len(get_tasks(script=["Latn"])) - assert len(default_tasks) >= len(get_tasks(domains=["Legal"])) - assert len(default_tasks) >= len(get_tasks(languages=["eng", "deu"])) - text_task = get_tasks(modalities=["text"]) - assert len(default_tasks) >= len(text_task) - assert len(get_tasks(modalities=["text", "image"])) >= len(text_task) +@pytest.fixture +def all_tasks(): + return get_tasks() @pytest.mark.parametrize( "task_name", ["BornholmBitextMining", "CQADupstackRetrieval", "Birdsnap"] ) @pytest.mark.parametrize("eval_splits", [["test"], None]) -@pytest.mark.parametrize("modalities", [["text"], None]) -@pytest.mark.parametrize("exclusive_modality_filter", [True, False]) def test_get_task( task_name: str, eval_splits: list[str] | None, - modalities: list[Modalities] | None, - exclusive_modality_filter: bool, ): task = get_task( task_name, eval_splits=eval_splits, - modalities=modalities, - exclusive_modality_filter=exclusive_modality_filter, ) assert isinstance(task, AbsTask) assert task.metadata.name == task_name @@ -45,128 +33,6 @@ def test_get_task( else: assert task.eval_splits == task.metadata.eval_splits - if modalities: - if task.modalities: - if exclusive_modality_filter: - # With exclusive filter, task modalities must exactly match the requested modalities - assert set(task.modalities) == set(modalities) - else: - # With inclusive filter, task modalities must have overlap with requested modalities - assert any(mod in task.modalities for mod in modalities) - - -@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"], None]) -@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None]) -@pytest.mark.parametrize("domains", [["Legal"], ["Medical", "Non-fiction"], None]) -@pytest.mark.parametrize("task_types", [["Classification"], None]) -def test_get_tasks( - languages: list[str], - script: list[str], - domains: list[TaskDomain], - task_types: list[TaskType] | None, # type: ignore -): - """Tests that get_tasks filters tasks correctly. This could in principle be combined with the following tests, but they have been kept - seperate to reduce the grid size. 
- """ - tasks = mteb.get_tasks( - languages=languages, - script=script, - domains=domains, - task_types=task_types, - ) - - for task in tasks: - if languages: - assert set(languages).intersection(task.metadata.languages) - if script: - assert set(script).intersection(task.metadata.scripts) - if domains: - task_domains = ( - set(task.metadata.domains) if task.metadata.domains else set() - ) - assert set(domains).intersection(set(task_domains)) - if task_types: - assert task.metadata.type in task_types - - -@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"]]) -@pytest.mark.parametrize("domains", [["Medical", "Non-fiction"], None]) -@pytest.mark.parametrize("task_types", [["Classification"], None]) -@pytest.mark.parametrize("exclude_superseded_datasets", [True, False]) -def test_get_tasks_superseded( - languages: list[str], - domains: list[TaskDomain], - task_types: list[TaskType] | None, # type: ignore - exclude_superseded_datasets: bool, -): - tasks = mteb.get_tasks( - languages=languages, - domains=domains, - task_types=task_types, - exclude_superseded=exclude_superseded_datasets, - ) - - for task in tasks: - if languages: - assert set(languages).intersection(task.metadata.languages) - if domains: - task_domains = ( - set(task.metadata.domains) if task.metadata.domains else set() - ) - assert set(domains).intersection(set(task_domains)) - if task_types: - assert task.metadata.type in task_types - if exclude_superseded_datasets: - assert task.superseded_by is None - - -@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"]]) -@pytest.mark.parametrize("modalities", [["text"], ["image"], ["text", "image"], None]) -@pytest.mark.parametrize("exclusive_modality_filter", [True, False]) -def test_get_tasks_modalities( - languages: list[str], - modalities: list[Modalities] | None, - exclusive_modality_filter: bool, -): - tasks = mteb.get_tasks( - languages=languages, - modalities=modalities, - exclusive_modality_filter=exclusive_modality_filter, - ) - - for task in tasks: - if languages: - assert set(languages).intersection(task.metadata.languages) - if modalities: - if exclusive_modality_filter: - assert set(task.modalities) == set(modalities) - else: - assert any(mod in task.modalities for mod in modalities) - - -@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"], None]) -@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None]) -@pytest.mark.parametrize("exclude_aggregate", [True, False]) -def test_get_tasks_exclude_aggregate( - languages: list[str], - script: list[str], - exclude_aggregate: bool, -): - tasks = mteb.get_tasks( - languages=languages, - script=script, - exclude_aggregate=exclude_aggregate, - ) - - for task in tasks: - if languages: - assert set(languages).intersection(task.metadata.languages) - if script: - assert set(script).intersection(task.metadata.scripts) - if exclude_aggregate: - # Aggregate tasks should be excluded - assert not task.is_aggregate - def test_get_tasks_filtering(): """Tests that get_tasks filters tasks for languages within the task, i.e. 
that a multilingual task returns only relevant subtasks for the diff --git a/tests/test_integrations/test_encode_args_passed.py b/tests/test_integrations/test_encode_args_passed.py index af9611f667..5eba02e133 100644 --- a/tests/test_integrations/test_encode_args_passed.py +++ b/tests/test_integrations/test_encode_args_passed.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader import mteb -import mteb.overview +import mteb.get_tasks from mteb.abstasks import AbsTask from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder diff --git a/tests/test_integrations/test_modality.py b/tests/test_integrations/test_modality.py index 1c29f32b80..ffd00a1598 100644 --- a/tests/test_integrations/test_modality.py +++ b/tests/test_integrations/test_modality.py @@ -8,7 +8,7 @@ import pytest import mteb -import mteb.overview +import mteb.get_tasks from mteb.MTEB import logger from tests.mock_tasks import ( MockImageClusteringTask, diff --git a/tests/test_integrations/test_prompts.py b/tests/test_integrations/test_prompts.py index 54f8fcdcee..66155afe87 100644 --- a/tests/test_integrations/test_prompts.py +++ b/tests/test_integrations/test_prompts.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader import mteb -import mteb.overview +import mteb.get_tasks from mteb.abstasks import AbsTask from mteb.models.abs_encoder import AbsEncoder from tests.mock_models import ( diff --git a/tests/test_tasks/test_dataset_on_hf.py b/tests/test_tasks/test_dataset_on_hf.py index df46da8bb3..9bf6847169 100644 --- a/tests/test_tasks/test_dataset_on_hf.py +++ b/tests/test_tasks/test_dataset_on_hf.py @@ -6,7 +6,7 @@ import mteb from mteb.abstasks.aggregated_task import AbsTaskAggregate -from mteb.overview import get_tasks +from mteb.get_tasks import get_tasks from tests.task_grid import ( MOCK_MIEB_TASK_GRID_AS_STRING, MOCK_TASK_TEST_GRID_AS_STRING, diff --git a/tests/test_tasks/test_is_superseeded.py b/tests/test_tasks/test_is_superseeded.py index 0166efd49c..99748e4c19 100644 --- a/tests/test_tasks/test_is_superseeded.py +++ b/tests/test_tasks/test_is_superseeded.py @@ -5,7 +5,7 @@ import logging import mteb -from mteb.overview import _TASKS_REGISTRY +from mteb.get_tasks import _TASKS_REGISTRY logging.basicConfig(level=logging.INFO) diff --git a/tests/test_tasks/test_metadata.py b/tests/test_tasks/test_metadata.py index cf081af325..1f13af7491 100644 --- a/tests/test_tasks/test_metadata.py +++ b/tests/test_tasks/test_metadata.py @@ -3,7 +3,7 @@ import pytest from mteb.abstasks import AbsTask -from mteb.overview import get_tasks +from mteb.get_tasks import get_tasks # Historic datasets without filled metadata. Do NOT add new datasets to this list. _HISTORIC_DATASETS = [
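For downstream code, the import surface after this rename: the top-level package names are unchanged, and only imports from the private `mteb.overview` module need updating. A short sketch (the `STS22` lookup mirrors the docstring example above):

import mteb

tasks = mteb.get_tasks(languages=["eng"])  # unchanged
task = mteb.get_task("STS22")  # unchanged
legal = mteb.filter_tasks(tasks, domains=["Legal"])  # newly exported

# Private-module imports move mechanically, as throughout this diff:
# from mteb.overview import get_tasks  ->  from mteb.get_tasks import get_tasks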