2 changes: 2 additions & 0 deletions docs/api/task.md
@@ -14,6 +14,8 @@ A task is an implementation of a dataset for evaluation. It could, for instance,

:::mteb.get_task

+:::mteb.filter_tasks
+
## Metadata

Each task also contains extensive metadata. We annotate this using the following object, which allows us to use [pydantic](https://docs.pydantic.dev/latest/) to validate the metadata.
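As context for the documented filter API, a hedged sketch of reading that pydantic-validated metadata; these are the fields `filter_tasks` matches against, and "STS12" is just an illustrative task name, not part of this PR:

```python
import mteb

# Minimal sketch: fetch a task and inspect the TaskMetadata fields that
# filter_tasks reads. Values in the comments are examples, not guarantees.
task = mteb.get_task("STS12")
print(task.metadata.type)        # task type, e.g. "STS"
print(task.metadata.languages)   # ISO 639-3 codes, e.g. ["eng"]
print(task.metadata.domains)     # task domains; may be None
print(task.metadata.modalities)  # e.g. ["text"]
```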
2 changes: 1 addition & 1 deletion docs/overview/create_available_benchmarks.py
@@ -4,7 +4,7 @@
from typing import cast

import mteb
-from mteb.overview import MTEBTasks
+from mteb.get_tasks import MTEBTasks

START_INSERT = "<!-- START TASK DESCRIPTION -->"
END_INSERT = "<!-- END TASK DESCRIPTION -->"
2 changes: 1 addition & 1 deletion mteb/MTEB.py
@@ -122,7 +122,7 @@ def _display_tasks(self, task_list: Iterable[AbsTask], name: str | None = None):

def mteb_benchmarks(self):
"""Get all benchmarks available in the MTEB."""
-from mteb.overview import MTEBTasks
+from mteb.get_tasks import MTEBTasks

# get all the MTEB specific benchmarks:
sorted_mteb_benchmarks = sorted(
4 changes: 3 additions & 1 deletion mteb/__init__.py
@@ -3,11 +3,12 @@
from mteb.abstasks import AbsTask
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.evaluate import evaluate
+from mteb.filter_tasks import filter_tasks
+from mteb.get_tasks import get_task, get_tasks
from mteb.load_results import load_results
from mteb.models import Encoder, SentenceTransformerEncoderWrapper
from mteb.models.get_model_meta import get_model, get_model_meta, get_model_metas
from mteb.MTEB import MTEB
-from mteb.overview import get_task, get_tasks
from mteb.results import BenchmarkResults, TaskResult

from .benchmarks.benchmark import Benchmark
@@ -25,6 +26,7 @@
"TaskMetadata",
"TaskResult",
"evaluate",
"filter_tasks",
"get_benchmark",
"get_benchmarks",
"get_model",
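With the re-export above, `filter_tasks` is reachable from the top-level `mteb` namespace and composes with `get_tasks`; a minimal sketch:

```python
import mteb

# filter_tasks is now re-exported at the package top level, so it can be
# chained with get_tasks without importing from submodules.
tasks = mteb.get_tasks(languages=["eng"])
retrieval_tasks = mteb.filter_tasks(tasks, task_types=["Retrieval"])
```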
26 changes: 0 additions & 26 deletions mteb/abstasks/abstask.py
@@ -432,32 +432,6 @@ def filter_eval_splits(self, eval_splits: list[str] | None) -> Self:
        self._eval_splits = eval_splits
        return self

-    def filter_modalities(
-        self, modalities: list[str] | None, exclusive_modality_filter: bool = False
-    ) -> Self:
-        """Filter the modalities of the task.
-
-        Args:
-            modalities: A list of modalities to filter by. If None, the task is returned unchanged.
-            exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
-                task's modalities and ALL task modalities are in filter modalities (exact match).
-                If False, keep tasks if _any_ of the task's modalities match the filter modalities.
-
-        Returns:
-            The filtered task
-        """
-        if modalities is None:
-            return self
-        filter_modalities_set = set(modalities)
-        task_modalities_set = set(self.modalities)
-        if exclusive_modality_filter:
-            if not (filter_modalities_set == task_modalities_set):
-                self.metadata.modalities = []
-        else:
-            if not filter_modalities_set.intersection(task_modalities_set):
-                self.metadata.modalities = []
-        return self
-
    def filter_languages(
        self,
        languages: list[str] | None,
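The per-task `filter_modalities` method removed above is subsumed by the new collection-level filter; a hedged migration sketch:

```python
import mteb

tasks = mteb.get_tasks()

# Before this PR, modalities were blanked in place on each task:
#   task.filter_modalities(["text", "image"], exclusive_modality_filter=True)
# The same intent is now a collection-level filter that keeps only tasks
# whose modalities exactly match the given set:
text_image_tasks = mteb.filter_tasks(
    tasks, modalities=["text", "image"], exclusive_modality_filter=True
)
```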
2 changes: 1 addition & 1 deletion mteb/benchmarks/_create_table.py
@@ -6,7 +6,7 @@
import pandas as pd

import mteb
-from mteb.overview import get_task, get_tasks
+from mteb.get_tasks import get_task, get_tasks
from mteb.results.benchmark_results import BenchmarkResults


2 changes: 1 addition & 1 deletion mteb/benchmarks/benchmarks/benchmarks.py
@@ -1,5 +1,5 @@
from mteb.benchmarks.benchmark import Benchmark, HUMEBenchmark, MIEBBenchmark
-from mteb.overview import MTEBTasks, get_task, get_tasks
+from mteb.get_tasks import MTEBTasks, get_task, get_tasks

MMTEB_CITATION = r"""@article{enevoldsen2025mmtebmassivemultilingualtext,
author = {Kenneth Enevoldsen and Isaac Chung and Imene Kerboua and Márton Kardos and Ashwin Mathur and David Stap and Jay Gala and Wissam Siblini and Dominik Krzemiński and Genta Indra Winata and Saba Sturua and Saiteja Utpala and Mathieu Ciancone and Marion Schaeffer and Gabriel Sequeira and Diganta Misra and Shreeya Dhakal and Jonathan Rystrøm and Roman Solomatin and Ömer Çağatan and Akash Kundu and Martin Bernstorff and Shitao Xiao and Akshita Sukhlecha and Bhavish Pahwa and Rafał Poświata and Kranthi Kiran GV and Shawon Ashraf and Daniel Auras and Björn Plüster and Jan Philipp Harries and Loïc Magne and Isabelle Mohr and Mariya Hendriksen and Dawei Zhu and Hippolyte Gisserot-Boukhlef and Tom Aarsen and Jan Kostkan and Konrad Wojtasik and Taemin Lee and Marek Šuppa and Crystina Zhang and Roberta Rocca and Mohammed Hamdy and Andrianos Michail and John Yang and Manuel Faysse and Aleksei Vatolin and Nandan Thakur and Manan Dey and Dipam Vasani and Pranjal Chitale and Simone Tedeschi and Nguyen Tai and Artem Snegirev and Michael Günther and Mengzhou Xia and Weijia Shi and Xing Han Lù and Jordan Clive and Gayatri Krishnakumar and Anna Maksimova and Silvan Wehrli and Maria Tikhonova and Henil Panchal and Aleksandr Abramov and Malte Ostendorff and Zheng Liu and Simon Clematide and Lester James Miranda and Alena Fenogenova and Guangyu Song and Ruqiya Bin Safi and Wen-Ding Li and Alessia Borghini and Federico Cassano and Hongjin Su and Jimmy Lin and Howard Yen and Lasse Hansen and Sara Hooker and Chenghao Xiao and Vaibhav Adlakha and Orion Weller and Siva Reddy and Niklas Muennighoff},
2 changes: 1 addition & 1 deletion mteb/benchmarks/benchmarks/rteb_benchmarks.py
@@ -2,7 +2,7 @@


from mteb.benchmarks.benchmark import RtebBenchmark
-from mteb.overview import get_tasks
+from mteb.get_tasks import get_tasks

RTEB_CITATION = r"""@article{rteb2025,
author = {Liu, Frank and Enevoldsen, Kenneth and Solomatin, Roman and Chung, Isaac and Aarsen, Tom and Fődi, Zoltán},
2 changes: 1 addition & 1 deletion mteb/cli/_display_tasks.py
@@ -2,7 +2,7 @@

from mteb.abstasks import AbsTask
from mteb.benchmarks import Benchmark
-from mteb.overview import MTEBTasks
+from mteb.get_tasks import MTEBTasks


def _display_benchmarks(benchmarks: Sequence[Benchmark]) -> None:
172 changes: 172 additions & 0 deletions mteb/filter_tasks.py
@@ -0,0 +1,172 @@
"""This script contains functions that are used to get an overview of the MTEB benchmark."""

import logging
from collections.abc import Sequence
from typing import overload

from mteb.abstasks import (
AbsTask,
)
from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType
from mteb.languages import (
ISO_TO_LANGUAGE,
ISO_TO_SCRIPT,
)
from mteb.types import Modalities

logger = logging.getLogger(__name__)


def _check_is_valid_script(script: str) -> None:
if script not in ISO_TO_SCRIPT:
raise ValueError(
f"Invalid script code: '{script}', you can see valid ISO 15924 codes using `from mteb.languages import ISO_TO_SCRIPT`."
)


def _check_is_valid_language(lang: str) -> None:
if lang not in ISO_TO_LANGUAGE:
raise ValueError(
f"Invalid language code: '{lang}', you can see valid ISO 639-3 codes using `from mteb.languages import ISO_TO_LANGUAGE`."
)


@overload
def filter_tasks(
tasks: Sequence[AbsTask],
*,
languages: list[str] | None = None,
script: list[str] | None = None,
domains: list[TaskDomain] | None = None,
task_types: list[TaskType] | None = None, # type: ignore
categories: list[TaskCategory] | None = None,
modalities: list[Modalities] | None = None,
exclusive_modality_filter: bool = False,
exclude_superseded: bool = False,
exclude_aggregate: bool = False,
exclude_private: bool = False,
) -> list[AbsTask]: ...


@overload
def filter_tasks(
tasks: Sequence[type[AbsTask]],
*,
languages: list[str] | None = None,
script: list[str] | None = None,
domains: list[TaskDomain] | None = None,
task_types: list[TaskType] | None = None, # type: ignore
categories: list[TaskCategory] | None = None,
modalities: list[Modalities] | None = None,
exclusive_modality_filter: bool = False,
exclude_superseded: bool = False,
exclude_aggregate: bool = False,
exclude_private: bool = False,
) -> list[type[AbsTask]]: ...


def filter_tasks(
tasks: Sequence[AbsTask] | Sequence[type[AbsTask]],
*,
languages: list[str] | None = None,
script: list[str] | None = None,
domains: list[TaskDomain] | None = None,
task_types: list[TaskType] | None = None, # type: ignore
categories: list[TaskCategory] | None = None,
modalities: list[Modalities] | None = None,
exclusive_modality_filter: bool = False,
exclude_superseded: bool = False,
exclude_aggregate: bool = False,
exclude_private: bool = False,
) -> list[AbsTask] | list[type[AbsTask]]:
"""Filter tasks based on the specified criteria.

Args:
        tasks: A sequence of task objects (or task classes) to filter.
        languages: A list of three-letter language codes (ISO 639-3, e.g. "eng"). If None, all languages are included.
        script: A list of script codes (ISO 15924, e.g. "Latn"). If None, all scripts are included.
        domains: A list of task domains, e.g. "Legal", "Medical", "Fiction". If None, all domains are included.
        task_types: A list of task types, e.g. "Classification" or "Retrieval". If None, all task types are included.
        categories: A list of task categories, e.g. "t2t" (text-to-text) or "t2i" (text-to-image). See TaskMetadata for the full list.
            If None, all categories are included.
        modalities: A list of modalities to include. If None, all modalities are included.
        exclusive_modality_filter: If True, only keep tasks whose modalities exactly match the filter modalities (set equality).
            If False, keep tasks if any of the task's modalities overlap with the filter modalities.
        exclude_superseded: If True, exclude datasets which are superseded by another.
        exclude_aggregate: If True, exclude aggregate tasks. If False, both aggregate and non-aggregate tasks are returned.
        exclude_private: If True, exclude private/closed datasets (is_public=False). If False (default), both public and private datasets are included.

Returns:
        A list of task objects that pass all of the filters.

Examples:
>>> text_classification_tasks = filter_tasks(my_tasks, task_types=["Classification"], modalities=["text"])
>>> medical_tasks = filter_tasks(my_tasks, domains=["Medical"])
>>> english_tasks = filter_tasks(my_tasks, languages=["eng"])
>>> latin_script_tasks = filter_tasks(my_tasks, script=["Latn"])
>>> text_image_tasks = filter_tasks(my_tasks, modalities=["text", "image"], exclusive_modality_filter=True)

"""
langs_to_keep = None
if languages:
        for lang in languages:
            _check_is_valid_language(lang)
langs_to_keep = set(languages)

script_to_keep = None
if script:
        for s in script:
            _check_is_valid_script(s)
script_to_keep = set(script)

domains_to_keep = None
if domains:
domains_to_keep = set(domains)

def _convert_to_set(domain: list[TaskDomain] | None) -> set:
return set(domain) if domain is not None else set()

task_types_to_keep = None
if task_types:
task_types_to_keep = set(task_types)

categories_to_keep = None
if categories:
categories_to_keep = set(categories)

modalities_to_keep = None
if modalities:
modalities_to_keep = set(modalities)

_tasks = []
for t in tasks:
if langs_to_keep and not langs_to_keep.intersection(t.metadata.languages):
continue
if script_to_keep and not script_to_keep.intersection(t.metadata.scripts):
continue
if domains_to_keep and not domains_to_keep.intersection(
_convert_to_set(t.metadata.domains)
):
continue
if task_types_to_keep and t.metadata.type not in task_types_to_keep:
continue
if categories_to_keep and t.metadata.category not in categories_to_keep:
continue
if modalities_to_keep:
if exclusive_modality_filter:
if set(t.metadata.modalities) != modalities_to_keep:
continue
else:
if not modalities_to_keep.intersection(t.metadata.modalities):
continue
if exclude_superseded and t.superseded_by is not None:
continue
if exclude_aggregate and t.is_aggregate:
continue
if exclude_private and not t.metadata.is_public:
continue

_tasks.append(t)

return _tasks
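Taken together, a minimal usage sketch of the new module; the surrounding task collection and printed count are assumptions for illustration, not outputs of this PR:

```python
import mteb
from mteb.filter_tasks import filter_tasks

all_tasks = mteb.get_tasks()

# Criteria stack: a task must pass every supplied filter to be kept.
subset = filter_tasks(
    all_tasks,
    languages=["eng"],            # validated against ISO 639-3 codes
    task_types=["Classification"],
    modalities=["text"],
    exclude_superseded=True,      # drop tasks with a superseded_by pointer
    exclude_aggregate=True,
)
print(len(subset))

# Invalid codes fail fast with a pointer to the valid code tables:
try:
    filter_tasks(all_tasks, script=["latin"])  # "Latn" is the ISO 15924 code
except ValueError as err:
    print(err)
```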