2 changes: 1 addition & 1 deletion docs/adding_a_dataset.md
@@ -76,7 +76,7 @@ evaluation = MTEB(tasks=[SciDocsReranking()])
evaluation.run(model)
```

> **Note:** for multilingual / crosslingual tasks, make sure your class also inherits from the `MultilingualTask` class like in [this](https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Classification/multilingual/MTOPIntentClassification.py) example.
> **Note:** for multilingual / crosslingual tasks, make sure you've specified `eval_langs` as a dictionary, as shown in [this example](../mteb/tasks/Classification/multilingual/MTOPIntentClassification.py).



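As a concrete illustration of the convention this note describes (the subset names and language codes below are invented for the example), a multilingual / crosslingual task's `eval_langs` maps each Hugging Face subset to its languages, while a monolingual task passes a plain list:

```python
# Illustrative values only: a dict-valued eval_langs marks the task as
# multilingual / crosslingual; each key is a Hugging Face subset name and
# each value lists that subset's languages as ISO 639-3 + script codes.
eval_langs = {
    "en-de": ["eng-Latn", "deu-Latn"],  # a crosslingual pair subset
    "fr": ["fra-Latn"],                 # a monolingual subset
}

# A monolingual task would instead use a plain list, e.g.:
monolingual_eval_langs = ["eng-Latn"]
```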
47 changes: 42 additions & 5 deletions mteb/abstasks/AbsTask.py
@@ -74,6 +74,9 @@ class AbsTask(ABC):
abstask_prompt: The potential prompt of the abstask
superseded_by: Denotes the task that supersedes this task. Used to issue a warning to users of outdated datasets, while maintaining
reproducibility of existing benchmarks.
fast_loading: (Not recommended) Denotes whether the task should be loaded with the fast loading method.
This is only possible if the dataset has a "default" config. We don't recommend this method and suggest loading datasets via separate subsets instead.
It exists only for historical reasons and will be removed in the future.
"""

metadata: TaskMetadata
@@ -82,8 +85,8 @@ class AbsTask(ABC):
superseded_by: str | None = None
dataset: dict[HFSubset, DatasetDict] | None = None # type: ignore
data_loaded: bool = False
is_multilingual: bool = False
hf_subsets: list[HFSubset] | None = None
fast_loading: bool = False

def __init__(self, seed: int = 42, **kwargs: Any):
"""The init function. This is called primarily to set the seed.
@@ -96,6 +99,7 @@ def __init__(self, seed: int = 42, **kwargs: Any):

self.seed = seed
self.rng_state, self.np_rng = set_seed(seed)
self.hf_subsets = self.metadata.hf_subsets

def check_if_dataset_is_superseded(self):
"""Check if the dataset is superseded by a newer version"""
@@ -224,10 +228,43 @@ def load_data(self, **kwargs):
"""
if self.data_loaded:
return
self.dataset = datasets.load_dataset(**self.metadata.dataset) # type: ignore
if self.metadata.is_multilingual:
if self.fast_loading:
self.fast_load()
else:
self.dataset = {}
for hf_subset in self.hf_subsets:
self.dataset[hf_subset] = datasets.load_dataset(
name=hf_subset,
**self.metadata.dataset,
)
else:
# some monolingual datasets explicitly add the split name to the dataset name
self.dataset = datasets.load_dataset(**self.metadata.dataset) # type: ignore
self.dataset_transform()
self.data_loaded = True

def fast_load(self, **kwargs):
# todo remove
"""Load all subsets at once, then group by language with Polars. Using fast loading has two requirements:
- Each row in the dataset should have a 'lang' feature giving the corresponding language/language pair
- The dataset must have a 'default' config that loads all of its subsets (see https://huggingface.co/docs/datasets/en/repository_structure#configurations)
"""
self.dataset = {}
merged_dataset = datasets.load_dataset(
**self.metadata.dataset
) # load "default" subset
for split in merged_dataset.keys():
df_split = merged_dataset[split].to_polars()
df_grouped = dict(df_split.group_by(["lang"]))
for lang in set(df_split["lang"].unique()) & set(self.hf_subsets):
self.dataset.setdefault(lang, {})
self.dataset[lang][split] = datasets.Dataset.from_polars(
df_grouped[(lang,)].drop("lang")
) # Remove the lang column and convert back to an HF Dataset; not strictly necessary, but better for compatibility
for lang, subset in self.dataset.items():
self.dataset[lang] = datasets.DatasetDict(subset)

def calculate_metadata_metrics(
self, overwrite_results: bool = False
) -> dict[str, DescriptiveStatistics | dict[str, DescriptiveStatistics]]:
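To make the Polars round-trip in `fast_load` concrete, here is a self-contained toy version (invented sentences and subset names; assumes `polars` is installed and a recent `datasets` release providing `Dataset.to_polars` / `Dataset.from_polars`):

```python
import datasets

# Toy stand-in for a "default" config: one split containing every subset,
# with a `lang` column identifying each row's language pair.
merged_dataset = datasets.DatasetDict(
    {
        "test": datasets.Dataset.from_dict(
            {
                "sentence1": ["Hello", "Bonjour", "Hallo"],
                "sentence2": ["Salut", "Salut", "Hello"],
                "lang": ["en-fr", "fr-fr", "de-en"],
            }
        )
    }
)
declared_subsets = ["en-fr", "de-en"]  # subsets the task declares; "fr-fr" is skipped

dataset: dict[str, dict[str, datasets.Dataset]] = {}
for split in merged_dataset:
    df_split = merged_dataset[split].to_polars()
    df_grouped = dict(df_split.group_by(["lang"]))  # group keys are 1-tuples
    for lang in set(df_split["lang"].unique()) & set(declared_subsets):
        dataset.setdefault(lang, {})[split] = datasets.Dataset.from_polars(
            df_grouped[(lang,)].drop("lang")  # drop lang column, back to HF
        )

final = {lang: datasets.DatasetDict(splits) for lang, splits in dataset.items()}
print({lang: dd["test"].num_rows for lang, dd in final.items()})
# e.g. {'en-fr': 1, 'de-en': 1} (key order may vary)
```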
@@ -248,14 +285,14 @@ def calculate_metadata_metrics(
for split in pbar_split:
pbar_split.set_postfix_str(f"Split: {split}")
logger.info(f"Processing metadata for split {split}")
if self.is_multilingual:
if self.metadata.is_multilingual:
descriptive_stats[split] = self._calculate_metrics_from_split(
split, compute_overall=True
)
descriptive_stats[split][hf_subset_stat] = {}

pbar_subsets = tqdm.tqdm(
self.metadata.hf_subsets_to_langscripts,
self.metadata.hf_subsets,
desc="Processing Languages...",
)
for hf_subset in pbar_subsets:
@@ -345,7 +382,7 @@ def _add_main_score(self, scores: dict[HFSubset, ScoresDict]) -> None:
scores["main_score"] = scores[self.metadata.main_score]

def _upload_dataset_to_hub(self, repo_name: str, fields: list[str]) -> None:
if self.is_multilingual:
if self.metadata.is_multilingual:
for config in self.metadata.eval_langs:
logger.info(f"Converting {config} of {self.metadata.name}")
sentences = {}
4 changes: 2 additions & 2 deletions mteb/abstasks/AbsTaskBitextMining.py
@@ -73,7 +73,7 @@ def evaluate(
if not self.data_loaded:
self.load_data()

hf_subsets = list(self.dataset) if self.is_multilingual else ["default"]
hf_subsets = self.hf_subsets

# If subsets_to_run is specified, filter the hf_subsets accordingly
if subsets_to_run is not None:
@@ -191,7 +191,7 @@ def _calculate_metrics_from_split(
)

def _push_dataset_to_hub(self, repo_name: str) -> None:
if self.is_multilingual:
if self.metadata.is_multilingual:
for config in self.metadata.eval_langs:
logger.info(f"Converting {config} of {self.metadata.name}")

2 changes: 1 addition & 1 deletion mteb/abstasks/AbsTaskClassification.py
@@ -89,7 +89,7 @@ def evaluate(
self.load_data()

scores = {}
hf_subsets = list(self.dataset) if self.is_multilingual else ["default"]
hf_subsets = self.hf_subsets
if subsets_to_run is not None:
hf_subsets = [s for s in hf_subsets if s in subsets_to_run]

2 changes: 1 addition & 1 deletion mteb/abstasks/AbsTaskReranking.py
@@ -99,7 +99,7 @@ def transform_old_dataset_format(self, given_dataset=None):
self.relevant_docs = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
self.top_ranked = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

hf_subsets = list(self.hf_subsets) if self.is_multilingual else ["default"]
hf_subsets = self.hf_subsets

for hf_subset in hf_subsets:
if given_dataset:
6 changes: 3 additions & 3 deletions mteb/abstasks/AbsTaskRetrieval.py
@@ -178,7 +178,7 @@ def process_data(split: str, lang: str | None = None):
else:
self.top_ranked[split] = top_ranked

if self.is_multilingual:
if self.metadata.is_multilingual:
for lang in self.metadata.eval_langs:
for split in eval_splits:
process_data(split, lang)
@@ -204,7 +204,7 @@ def evaluate(
)

scores = {}
hf_subsets = list(self.hf_subsets) if self.is_multilingual else ["default"]
hf_subsets = self.hf_subsets
if subsets_to_run is not None:
hf_subsets = [s for s in hf_subsets if s in subsets_to_run]

@@ -515,7 +515,7 @@ def format_text_field(text):
return text
return f"{text.get('title', '')} {text.get('text', '')}".strip()

if self.is_multilingual:
if self.metadata.is_multilingual:
for config in self.queries:
logger.info(f"Converting {config} of {self.metadata.name}")

61 changes: 0 additions & 61 deletions mteb/abstasks/MultilingualTask.py

This file was deleted.

10 changes: 10 additions & 0 deletions mteb/abstasks/TaskMetadata.py
@@ -446,5 +446,15 @@ def n_samples(self) -> dict[str, int] | None:
n_samples[subset] = subset_value["num_samples"]
return n_samples

@property
def hf_subsets(self) -> list[str]:
"""Return the huggingface subsets."""
return list(self.hf_subsets_to_langscripts.keys())

@property
def is_multilingual(self) -> bool:
"""Check if the task is multilingual."""
return isinstance(self.eval_langs, dict)

def __hash__(self) -> int:
return hash(self.model_dump_json())
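A self-contained sketch of the logic these two new properties implement (a standalone mirror for illustration, not the real class):

```python
# Mirrors TaskMetadata.is_multilingual / TaskMetadata.hf_subsets:
# dict-valued eval_langs => multilingual, and its keys are the HF subsets;
# list-valued eval_langs => a single "default" subset.
def is_multilingual(eval_langs) -> bool:
    return isinstance(eval_langs, dict)

def hf_subsets(eval_langs) -> list[str]:
    if isinstance(eval_langs, dict):
        return list(eval_langs.keys())
    return ["default"]

print(is_multilingual(["eng-Latn"]))                    # False
print(hf_subsets({"en-de": ["eng-Latn", "deu-Latn"]}))  # ['en-de']
```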
2 changes: 0 additions & 2 deletions mteb/abstasks/__init__.py
@@ -12,7 +12,6 @@
from .AbsTaskSpeedTask import AbsTaskSpeedTask
from .AbsTaskSTS import AbsTaskSTS
from .AbsTaskSummarization import AbsTaskSummarization
from .MultilingualTask import MultilingualTask
from .TaskMetadata import TaskMetadata

__all__ = [
@@ -28,6 +27,5 @@
"AbsTaskSpeedTask",
"AbsTaskSTS",
"AbsTaskSummarization",
"MultilingualTask",
"TaskMetadata",
]
4 changes: 2 additions & 2 deletions mteb/evaluation/MTEB.py
@@ -98,7 +98,7 @@ def _display_tasks(self, task_list: Iterable[AbsTask], name: str | None = None):
category = f", [italic grey39]{task.metadata.category}[/]"
multilingual = (
f", [italic red]multilingual {len(task.hf_subsets)} / {len(task.metadata.eval_langs)} Subsets[/]"
if task.is_multilingual
if task.metadata.is_multilingual
else ""
)
console.print(f"{prefix}{name}{category}{multilingual}")
@@ -326,7 +326,7 @@ def run(
task_eval_splits = (
eval_splits if eval_splits is not None else task.eval_splits
)
task_subsets = list(task.metadata.hf_subsets_to_langscripts.keys())
task_subsets = task.metadata.hf_subsets

existing_results = None
save_path = None
2 changes: 1 addition & 1 deletion mteb/load_results/task_results.py
@@ -540,7 +540,7 @@ def validate_and_filter_scores(self, task: AbsTask | None = None) -> TaskResult:
if task is None:
task = get_task(self.task_name)
splits = task.metadata.eval_splits
if task.is_multilingual:
if task.metadata.is_multilingual:
hf_subsets = getattr(
task, "hf_subsets", task.metadata.hf_subsets_to_langscripts.keys()
)
3 changes: 1 addition & 2 deletions mteb/tasks/BitextMining/kat/TbilisiCityHallBitextMining.py
@@ -3,7 +3,6 @@
from datasets import DatasetDict, load_dataset

from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining
from mteb.abstasks.MultilingualTask import MultilingualTask
from mteb.abstasks.TaskMetadata import TaskMetadata

_LANGUAGES = {
@@ -18,7 +17,7 @@
_EVAL_SPLIT = "test"


class TbilisiCityHallBitextMining(AbsTaskBitextMining, MultilingualTask):
class TbilisiCityHallBitextMining(AbsTaskBitextMining):
metadata = TaskMetadata(
name="TbilisiCityHallBitextMining",
dataset={
3 changes: 1 addition & 2 deletions mteb/tasks/BitextMining/multilingual/BUCCBitextMining.py
@@ -3,7 +3,6 @@
import logging

from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining
from mteb.abstasks.MultilingualTask import MultilingualTask
from mteb.abstasks.TaskMetadata import TaskMetadata

_LANGUAGES = {
@@ -19,7 +18,7 @@
logger = logging.getLogger(__name__)


class BUCCBitextMining(AbsTaskBitextMining, MultilingualTask):
class BUCCBitextMining(AbsTaskBitextMining):
superseded_by = "BUCC.v2"
metadata = TaskMetadata(
name="BUCC",
3 changes: 1 addition & 2 deletions mteb/tasks/BitextMining/multilingual/BUCCBitextMiningFast.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining
from mteb.abstasks.MultilingualTask import MultilingualTask
from mteb.abstasks.TaskMetadata import TaskMetadata

_LANGUAGES = {
@@ -15,7 +14,7 @@
_SPLITS = ["test"]


class BUCCBitextMiningFast(AbsTaskBitextMining, MultilingualTask):
class BUCCBitextMiningFast(AbsTaskBitextMining):
fast_loading = True
metadata = TaskMetadata(
name="BUCC.v2",
3 changes: 1 addition & 2 deletions mteb/tasks/BitextMining/multilingual/BibleNLPBitextMining.py
@@ -5,7 +5,6 @@
import datasets

from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining
from mteb.abstasks.MultilingualTask import MultilingualTask
from mteb.abstasks.TaskMetadata import TaskMetadata

_LANGUAGES = [
@@ -859,7 +858,7 @@ def extend_lang_pairs_english_centric() -> dict[str, list[str]]:
_LANGUAGES_MAPPING = extend_lang_pairs_english_centric()


class BibleNLPBitextMining(AbsTaskBitextMining, MultilingualTask):
class BibleNLPBitextMining(AbsTaskBitextMining):
metadata = TaskMetadata(
name="BibleNLPBitextMining",
dataset={
3 changes: 1 addition & 2 deletions mteb/tasks/BitextMining/multilingual/DiaBLaBitextMining.py
@@ -3,11 +3,10 @@
import datasets

from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining
from mteb.abstasks.MultilingualTask import MultilingualTask
from mteb.abstasks.TaskMetadata import TaskMetadata


class DiaBLaBitextMining(AbsTaskBitextMining, MultilingualTask):
class DiaBLaBitextMining(AbsTaskBitextMining):
metadata = TaskMetadata(
name="DiaBlaBitextMining",
dataset={
3 changes: 1 addition & 2 deletions mteb/tasks/BitextMining/multilingual/FloresBitextMining.py
@@ -5,7 +5,6 @@
import datasets

from mteb.abstasks.AbsTaskBitextMining import AbsTaskBitextMining
from mteb.abstasks.MultilingualTask import MultilingualTask
from mteb.abstasks.TaskMetadata import TaskMetadata

_LANGUAGES = [
@@ -235,7 +234,7 @@ def extend_lang_pairs() -> dict[str, list[str]]:
_LANGUAGES_MAPPING = extend_lang_pairs()


class FloresBitextMining(AbsTaskBitextMining, MultilingualTask):
class FloresBitextMining(AbsTaskBitextMining):
parallel_subsets = True
metadata = TaskMetadata(
name="FloresBitextMining",