mteb/abstasks/retrieval.py (4 changes: 2 additions & 2 deletions)

@@ -263,7 +263,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
             return
 
         dataset_path = self.metadata.dataset["path"]
-        eval_splits = self.metadata.eval_splits
+        eval_splits = self.eval_splits
         trust_remote_code = self.metadata.dataset.get("trust_remote_code", False)
         revision = self.metadata.dataset["revision"]
 
@@ -284,7 +284,7 @@ def _process_data(split: str, hf_subset: str = "default"):
            )
 
        if self.metadata.is_multilingual:
-            for lang in self.metadata.eval_langs:
+            for lang in self.hf_subsets:
                for split in eval_splits:
                    _process_data(split, lang)
        else:
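Taken together with the new tests below, the point of this change is that load_data now iterates the instance-level eval_splits and hf_subsets rather than the raw metadata, so a task constructed with a narrowed subset or split selection loads only what was requested. A minimal, hypothetical sketch of that pattern (Task and TaskMetadata here are illustrative stand-ins, not the real mteb classes):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TaskMetadata:
    eval_splits: list[str]
    eval_langs: list[str]


class Task:
    def __init__(self, metadata: TaskMetadata, hf_subsets: list[str] | None = None):
        self.metadata = metadata
        # Instance-level copies: default to the metadata, but can be narrowed.
        self.eval_splits = metadata.eval_splits
        self.hf_subsets = hf_subsets if hf_subsets is not None else metadata.eval_langs

    def load_data(self) -> None:
        # Iterating the instance attributes (not self.metadata.*) respects
        # any narrowing applied at construction time.
        for subset in self.hf_subsets:
            for split in self.eval_splits:
                print(f"loading {subset}/{split}")


task = Task(TaskMetadata(eval_splits=["dev", "test"], eval_langs=["en", "fr"]), hf_subsets=["fr"])
task.load_data()  # prints only fr/dev and fr/test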
mteb/abstasks/retrieval_dataset_loaders.py (18 changes: 8 additions & 10 deletions)

@@ -139,9 +139,9 @@ def _load_dataset_split(self, config: str, num_proc: int) -> Dataset:
        )
 
    def _load_corpus(self, num_proc: int) -> CorpusDatasetType:
-        logger.info("Loading Corpus...")
-
        config = f"{self.config}-corpus" if self.config is not None else "corpus"
+        logger.info("Loading corpus subset: %s", config)
+
        corpus_ds = self._load_dataset_split(config, num_proc)
        if "_id" in corpus_ds.column_names:
            corpus_ds = corpus_ds.cast_column("_id", Value("string")).rename_column(
@@ -152,9 +152,9 @@ def _load_corpus(self, num_proc: int) -> CorpusDatasetType:
        return corpus_ds
 
    def _load_queries(self, num_proc: int) -> QueryDatasetType:
-        logger.info("Loading Queries...")
-
        config = f"{self.config}-queries" if self.config is not None else "queries"
+        logger.info("Loading queries subset: %s", config)
+
        if "query" in self.dataset_configs:
            config = "query"
        queries_ds = self._load_dataset_split(config, num_proc)
@@ -169,9 +169,9 @@ def _load_queries(self, num_proc: int) -> QueryDatasetType:
        return queries_ds
 
    def _load_qrels(self, num_proc: int) -> RelevantDocumentsType:
-        logger.info("Loading qrels...")
-
        config = f"{self.config}-qrels" if self.config is not None else "default"
 
+        logger.info("Loading qrels subset: %s", config)
        if config == "default" and config not in self.dataset_configs:
            if "qrels" in self.dataset_configs:
                config = "qrels"
@@ -204,11 +204,10 @@ def _load_qrels(self, num_proc: int) -> RelevantDocumentsType:
        return qrels_dict
 
    def _load_top_ranked(self, num_proc: int) -> TopRankedDocumentsType:
-        logger.info("Loading Top Ranked")
-
        config = (
            f"{self.config}-top_ranked" if self.config is not None else "top_ranked"
        )
+        logger.info("Loading top ranked subset: %s", config)
        top_ranked_ds = self._load_dataset_split(config, num_proc)
        top_ranked_ds = top_ranked_ds.cast(
            Features(
@@ -228,11 +227,10 @@ def _load_top_ranked(self, num_proc: int) -> TopRankedDocumentsType:
        return top_ranked_dict
 
    def _load_instructions(self, num_proc: int) -> InstructionDatasetType:
-        logger.info("Loading Instructions")
-
        config = (
            f"{self.config}-instruction" if self.config is not None else "instruction"
        )
+        logger.info("Loading instruction subset: %s", config)
        instructions_ds = self._load_dataset_split(config, num_proc)
        instructions_ds = instructions_ds.cast(
            Features(
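All five loaders get the same two-part fix: compute the config name first, then log it with lazy %s-style arguments instead of a fixed message. A small self-contained sketch of the idiom (the logger name and config value are illustrative):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("retrieval_loader")  # illustrative name

# e.g. f"{self.config}-corpus" in the loader above
config = "fr-corpus"

# %s-style arguments defer string interpolation to the logging framework,
# so nothing is formatted when INFO records are filtered out, and the
# message now identifies exactly which subset is being loaded.
logger.info("Loading corpus subset: %s", config)
# -> INFO:retrieval_loader:Loading corpus subset: fr-corpus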
tests/test_tasks/test_load_data.py (new file: 52 additions & 0 deletions)

@@ -0,0 +1,52 @@
+from unittest.mock import patch
+
+import pytest
+from datasets import Dataset, DatasetDict
+
+import mteb
+
+
+@pytest.mark.parametrize(
+    "task",
+    [
+        mteb.get_task("DiaBlaBitextMining", hf_subsets=["fr-en"]),
+        mteb.get_task("AmazonCounterfactualClassification", hf_subsets=["en"]),
+        mteb.get_task("WikiClusteringP2P", hf_subsets=["bs"]),
+        mteb.get_task("MultiEURLEXMultilabelClassification", hf_subsets=["en"]),
+        mteb.get_task("OpusparcusPC", hf_subsets=["en"]),
+        mteb.get_task("STS17MultilingualVisualSTS", hf_subsets=["en-en"]),
+    ],
+)
+def test_multilingual_load_data(task):
+    dummy_dataset = DatasetDict({"test": Dataset.from_dict({"text": ["test"]})})
+
+    with patch("mteb.abstasks.abstask.load_dataset") as mock_load:
+        mock_load.return_value = dummy_dataset
+        task.load_data()
+
+    assert mock_load.called
+    assert task.dataset is not None
+    assert len(task.dataset) == 1
+
+
+@pytest.mark.parametrize(
+    "task",
+    [
+        mteb.get_task("MIRACLRetrievalHardNegatives", languages=["eng"]),
+    ],
+)
+def test_multilingual_retrieval_load_data(task):
+    dummy_split = {
+        "corpus": Dataset.from_dict({"id": ["d1"], "text": ["doc"]}),
+        "queries": Dataset.from_dict({"id": ["q1"], "text": ["query"]}),
+        "relevant_docs": {"q1": {"d1": 1}},
+        "top_ranked": None,
+    }
+
+    with patch("mteb.abstasks.retrieval.RetrievalDatasetLoader.load") as mock_load:
+        mock_load.return_value = dummy_split
+        task.load_data()
+
+    assert mock_load.called
+    assert task.dataset is not None
+    assert len(task.dataset) == 1
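Both tests stub out the loading layer with unittest.mock.patch, so they run without network access. Note that patch() targets the name where it is looked up at call time (mteb.abstasks.abstask.load_dataset, mteb.abstasks.retrieval.RetrievalDatasetLoader.load), not where it is originally defined. A self-contained sketch of that rule, using a throwaway module rather than mteb:

import sys
import types
from unittest.mock import patch

# A throwaway module standing in for mteb.abstasks.abstask: it binds
# load_dataset into its own namespace, as `from datasets import load_dataset` would.
mod = types.ModuleType("fake_abstask")
mod.load_dataset = lambda path: {"train": path}
mod.load_data = lambda: mod.load_dataset("some/dataset")
sys.modules["fake_abstask"] = mod

# Patching the name inside the using module is what its caller actually resolves.
with patch("fake_abstask.load_dataset") as mock_load:
    mock_load.return_value = {"test": ["dummy"]}
    assert mod.load_data() == {"test": ["dummy"]}
assert mock_load.called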