diff --git a/mteb/abstasks/retrieval.py b/mteb/abstasks/retrieval.py index 62dbaa49ab..05d0db86b2 100644 --- a/mteb/abstasks/retrieval.py +++ b/mteb/abstasks/retrieval.py @@ -263,7 +263,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: return dataset_path = self.metadata.dataset["path"] - eval_splits = self.metadata.eval_splits + eval_splits = self.eval_splits trust_remote_code = self.metadata.dataset.get("trust_remote_code", False) revision = self.metadata.dataset["revision"] @@ -284,7 +284,7 @@ def _process_data(split: str, hf_subset: str = "default"): ) if self.metadata.is_multilingual: - for lang in self.metadata.eval_langs: + for lang in self.hf_subsets: for split in eval_splits: _process_data(split, lang) else: diff --git a/mteb/abstasks/retrieval_dataset_loaders.py b/mteb/abstasks/retrieval_dataset_loaders.py index 467e6349f8..a0da828e85 100644 --- a/mteb/abstasks/retrieval_dataset_loaders.py +++ b/mteb/abstasks/retrieval_dataset_loaders.py @@ -139,9 +139,9 @@ def _load_dataset_split(self, config: str, num_proc: int) -> Dataset: ) def _load_corpus(self, num_proc: int) -> CorpusDatasetType: - logger.info("Loading Corpus...") - config = f"{self.config}-corpus" if self.config is not None else "corpus" + logger.info("Loading corpus subset: %s", config) + corpus_ds = self._load_dataset_split(config, num_proc) if "_id" in corpus_ds.column_names: corpus_ds = corpus_ds.cast_column("_id", Value("string")).rename_column( @@ -152,9 +152,9 @@ def _load_corpus(self, num_proc: int) -> CorpusDatasetType: return corpus_ds def _load_queries(self, num_proc: int) -> QueryDatasetType: - logger.info("Loading Queries...") - config = f"{self.config}-queries" if self.config is not None else "queries" + logger.info("Loading queries subset: %s", config) + if "query" in self.dataset_configs: config = "query" queries_ds = self._load_dataset_split(config, num_proc) @@ -169,9 +169,9 @@ def _load_queries(self, num_proc: int) -> QueryDatasetType: return queries_ds def _load_qrels(self, num_proc: int) -> RelevantDocumentsType: - logger.info("Loading qrels...") - config = f"{self.config}-qrels" if self.config is not None else "default" + + logger.info("Loading qrels subset: %s", config) if config == "default" and config not in self.dataset_configs: if "qrels" in self.dataset_configs: config = "qrels" @@ -204,11 +204,10 @@ def _load_qrels(self, num_proc: int) -> RelevantDocumentsType: return qrels_dict def _load_top_ranked(self, num_proc: int) -> TopRankedDocumentsType: - logger.info("Loading Top Ranked") - config = ( f"{self.config}-top_ranked" if self.config is not None else "top_ranked" ) + logger.info("Loading top ranked subset: %s", config) top_ranked_ds = self._load_dataset_split(config, num_proc) top_ranked_ds = top_ranked_ds.cast( Features( @@ -228,11 +227,10 @@ def _load_top_ranked(self, num_proc: int) -> TopRankedDocumentsType: return top_ranked_dict def _load_instructions(self, num_proc: int) -> InstructionDatasetType: - logger.info("Loading Instructions") - config = ( f"{self.config}-instruction" if self.config is not None else "instruction" ) + logger.info("Loading instruction subset: %s", config) instructions_ds = self._load_dataset_split(config, num_proc) instructions_ds = instructions_ds.cast( Features( diff --git a/tests/test_tasks/test_load_data.py b/tests/test_tasks/test_load_data.py new file mode 100644 index 0000000000..8b3cbed2df --- /dev/null +++ b/tests/test_tasks/test_load_data.py @@ -0,0 +1,52 @@ +from unittest.mock import patch + +import pytest +from datasets import Dataset, DatasetDict + +import mteb + + +@pytest.mark.parametrize( + "task", + [ + mteb.get_task("DiaBlaBitextMining", hf_subsets=["fr-en"]), + mteb.get_task("AmazonCounterfactualClassification", hf_subsets=["en"]), + mteb.get_task("WikiClusteringP2P", hf_subsets=["bs"]), + mteb.get_task("MultiEURLEXMultilabelClassification", hf_subsets=["en"]), + mteb.get_task("OpusparcusPC", hf_subsets=["en"]), + mteb.get_task("STS17MultilingualVisualSTS", hf_subsets=["en-en"]), + ], +) +def test_multilingual_load_data(task): + dummy_dataset = DatasetDict({"test": Dataset.from_dict({"text": ["test"]})}) + + with patch("mteb.abstasks.abstask.load_dataset") as mock_load: + mock_load.return_value = dummy_dataset + task.load_data() + + assert mock_load.called + assert task.dataset is not None + assert len(task.dataset) == 1 + + +@pytest.mark.parametrize( + "task", + [ + mteb.get_task("MIRACLRetrievalHardNegatives", languages=["eng"]), + ], +) +def test_multilingual_retrieval_load_data(task): + dummy_split = { + "corpus": Dataset.from_dict({"id": ["d1"], "text": ["doc"]}), + "queries": Dataset.from_dict({"id": ["q1"], "text": ["query"]}), + "relevant_docs": {"q1": {"d1": 1}}, + "top_ranked": None, + } + + with patch("mteb.abstasks.retrieval.RetrievalDatasetLoader.load") as mock_load: + mock_load.return_value = dummy_split + task.load_data() + + assert mock_load.called + assert task.dataset is not None + assert len(task.dataset) == 1