From 773e688f6dc9ec79c2096df682a2e7e355f361c4 Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Mon, 26 Jan 2026 01:17:45 +0300 Subject: [PATCH] fix dataset transform --- mteb/abstasks/abstask.py | 3 ++- .../bitext_mining/multilingual/bible_nlp_bitext_mining.py | 2 +- .../classification/ben/bengali_document_classification.py | 4 ++-- .../ces/czech_product_review_sentiment_classification.py | 4 ++-- .../ces/czech_so_me_sentiment_classification.py | 2 +- .../multilingual/hin_dialect_classification.py | 2 +- .../multilingual/indic_lang_classification.py | 2 +- .../multilingual/indic_sentiment_classification.py | 2 +- .../multilingual/language_classification.py | 2 +- .../multilingual/south_african_lang_classification.py | 2 +- .../slk/slovak_movie_review_sentiment_classification.py | 4 ++-- .../classification/swa/swahili_news_classification.py | 4 ++-- mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py | 2 +- mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py | 2 +- mteb/tasks/clustering/nob/vg_hierarchical_clustering.py | 4 ++-- .../covid_disinformation_nl_multi_label_classification.py | 2 +- .../multilingual/pub_chem_wiki_pair_classification.py | 2 +- mteb/tasks/retrieval/code/code_rag.py | 8 ++++---- mteb/tasks/retrieval/dan/dan_fever_retrieval.py | 2 +- mteb/tasks/retrieval/dan/tv2_nordretrieval.py | 2 +- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py | 2 +- mteb/tasks/retrieval/nob/norquad.py | 2 +- mteb/tasks/retrieval/nob/snl_retrieval.py | 2 +- mteb/tasks/sts/multilingual/sem_rel24_sts.py | 2 +- .../sts/multilingual/sts_benchmark_multilingual_sts.py | 2 +- mteb/tasks/sts/por/assin2_sts.py | 2 +- 26 files changed, 35 insertions(+), 34 deletions(-) diff --git a/mteb/abstasks/abstask.py b/mteb/abstasks/abstask.py index f9d6b6d599..97797b8b96 100644 --- a/mteb/abstasks/abstask.py +++ b/mteb/abstasks/abstask.py @@ -116,7 +116,7 @@ def check_if_dataset_is_superseded(self) -> None: logger.warning(msg) warnings.warn(msg) - def dataset_transform(self, num_proc: int = 1): + def dataset_transform(self, num_proc: int = 1, **kwargs: Any) -> None: """A transform operations applied to the dataset after loading. This method is useful when the dataset from Huggingface is not in an `mteb` compatible format. @@ -124,6 +124,7 @@ def dataset_transform(self, num_proc: int = 1): Args: num_proc: Number of processes to use for the transformation. + kwargs: Additional keyword arguments passed to the load_dataset function. Keep for forward compatibility. """ pass diff --git a/mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py b/mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py index 1d2961ceb9..4f7efed0cf 100644 --- a/mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py +++ b/mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py @@ -914,7 +914,7 @@ def load_data(self, **kwargs: Any) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: # Convert to standard format for lang in self.hf_subsets: l1, l2 = (l.split("_")[0] for l in lang.split("-")) diff --git a/mteb/tasks/classification/ben/bengali_document_classification.py b/mteb/tasks/classification/ben/bengali_document_classification.py index 5a40c3ba0c..a3c941950e 100644 --- a/mteb/tasks/classification/ben/bengali_document_classification.py +++ b/mteb/tasks/classification/ben/bengali_document_classification.py @@ -43,7 +43,7 @@ class BengaliDocumentClassification(AbsTaskClassification): superseded_by="BengaliDocumentClassification.v2", ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"article": "text", "category": "label"} ) @@ -92,7 +92,7 @@ class BengaliDocumentClassificationV2(AbsTaskClassification): """, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.stratified_subsampling( self.dataset, seed=self.seed, splits=["test"] ) diff --git a/mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py b/mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py index cbd7393bf6..74e0ff7a40 100644 --- a/mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +++ b/mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py @@ -46,7 +46,7 @@ class CzechProductReviewSentimentClassification(AbsTaskClassification): ) samples_per_label = 16 - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"comment": "text", "rating_str": "label"} ) @@ -99,7 +99,7 @@ class CzechProductReviewSentimentClassificationV2(AbsTaskClassification): ) samples_per_label = 16 - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.stratified_subsampling( self.dataset, seed=self.seed, splits=["test"] ) diff --git a/mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py b/mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py index 904b50314b..8e5fb01164 100644 --- a/mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +++ b/mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py @@ -46,7 +46,7 @@ class CzechSoMeSentimentClassification(AbsTaskClassification): ) samples_per_label = 16 - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"comment": "text", "sentiment_int": "label"} ) diff --git a/mteb/tasks/classification/multilingual/hin_dialect_classification.py b/mteb/tasks/classification/multilingual/hin_dialect_classification.py index c1db7cb4a5..b4515a5de6 100644 --- a/mteb/tasks/classification/multilingual/hin_dialect_classification.py +++ b/mteb/tasks/classification/multilingual/hin_dialect_classification.py @@ -60,7 +60,7 @@ class HinDialectClassification(AbsTaskClassification): """, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"folksong": "text", "language": "label"} ) diff --git a/mteb/tasks/classification/multilingual/indic_lang_classification.py b/mteb/tasks/classification/multilingual/indic_lang_classification.py index 031f6f1cd9..3f8c9c8c9b 100644 --- a/mteb/tasks/classification/multilingual/indic_lang_classification.py +++ b/mteb/tasks/classification/multilingual/indic_lang_classification.py @@ -137,6 +137,6 @@ def load_data(self, **kwargs: Any) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.remove_columns(["language", "script"]) self.dataset = self.dataset.rename_columns({"native sentence": "text"}) diff --git a/mteb/tasks/classification/multilingual/indic_sentiment_classification.py b/mteb/tasks/classification/multilingual/indic_sentiment_classification.py index 7c5873ab84..7ec3a29b9d 100644 --- a/mteb/tasks/classification/multilingual/indic_sentiment_classification.py +++ b/mteb/tasks/classification/multilingual/indic_sentiment_classification.py @@ -52,7 +52,7 @@ class IndicSentimentClassification(AbsTaskClassification): """, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: label_map = {"Negative": 0, "Positive": 1} # Convert to standard format for lang in self.hf_subsets: diff --git a/mteb/tasks/classification/multilingual/language_classification.py b/mteb/tasks/classification/multilingual/language_classification.py index def6b440e6..1b5ae02df7 100644 --- a/mteb/tasks/classification/multilingual/language_classification.py +++ b/mteb/tasks/classification/multilingual/language_classification.py @@ -66,7 +66,7 @@ class LanguageClassification(AbsTaskClassification): """, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns({"labels": "label"}) self.dataset = self.stratified_subsampling( self.dataset, seed=self.seed, splits=["test"] diff --git a/mteb/tasks/classification/multilingual/south_african_lang_classification.py b/mteb/tasks/classification/multilingual/south_african_lang_classification.py index 99bda1453b..fae3133985 100644 --- a/mteb/tasks/classification/multilingual/south_african_lang_classification.py +++ b/mteb/tasks/classification/multilingual/south_african_lang_classification.py @@ -49,7 +49,7 @@ class SouthAfricanLangClassification(AbsTaskClassification): """, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {" text": "text", "lang_id": "label"} ) diff --git a/mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py b/mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py index 3849ecd201..1b5a3f2905 100644 --- a/mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +++ b/mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py @@ -35,7 +35,7 @@ class SlovakMovieReviewSentimentClassification(AbsTaskClassification): superseded_by="SlovakMovieReviewSentimentClassification.v2", ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns({"comment": "text"}) self.dataset = self.stratified_subsampling( @@ -76,7 +76,7 @@ class SlovakMovieReviewSentimentClassificationV2(AbsTaskClassification): adapted_from=["SlovakMovieReviewSentimentClassification"], ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.stratified_subsampling( self.dataset, seed=self.seed, splits=["test"] ) diff --git a/mteb/tasks/classification/swa/swahili_news_classification.py b/mteb/tasks/classification/swa/swahili_news_classification.py index 41880ee5e1..69c05c4084 100644 --- a/mteb/tasks/classification/swa/swahili_news_classification.py +++ b/mteb/tasks/classification/swa/swahili_news_classification.py @@ -37,7 +37,7 @@ class SwahiliNewsClassification(AbsTaskClassification): superseded_by="SwahiliNewsClassification.v2", ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"content": "text", "category": "label"} ) @@ -81,7 +81,7 @@ class SwahiliNewsClassificationV2(AbsTaskClassification): adapted_from=["SwahiliNewsClassification"], ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.stratified_subsampling( self.dataset, seed=self.seed, splits=["train"] ) diff --git a/mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py b/mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py index 7bc5c18c59..c11e428de8 100644 --- a/mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py +++ b/mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py @@ -63,7 +63,7 @@ class TenKGnadClusteringP2PFast(AbsTaskClustering): adapted_from=["TenKGnadClusteringP2P"], ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: ds = _convert_to_fast( self.dataset, self.input_column_name, self.label_column_name, self.seed ) diff --git a/mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py b/mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py index 8912cd67c0..d77ec450a3 100644 --- a/mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py +++ b/mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py @@ -63,7 +63,7 @@ class TenKGnadClusteringS2SFast(AbsTaskClustering): adapted_from=["TenKGnadClusteringS2S"], ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: ds = _convert_to_fast( self.dataset, self.input_column_name, self.label_column_name, self.seed ) diff --git a/mteb/tasks/clustering/nob/vg_hierarchical_clustering.py b/mteb/tasks/clustering/nob/vg_hierarchical_clustering.py index 0d6b36542e..505f4740d1 100644 --- a/mteb/tasks/clustering/nob/vg_hierarchical_clustering.py +++ b/mteb/tasks/clustering/nob/vg_hierarchical_clustering.py @@ -45,7 +45,7 @@ class VGHierarchicalClusteringP2P(AbsTaskClustering): prompt="Identify the categories (e.g. sports) of given articles in Norwegian", ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"article": "sentences", "classes": "labels"} ) @@ -92,7 +92,7 @@ class VGHierarchicalClusteringS2S(AbsTaskClustering): prompt="Identify the categories (e.g. sports) of given articles in Norwegian", ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( {"ingress": "sentences", "classes": "labels"} ) diff --git a/mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py b/mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py index 0274984150..3d70cae267 100644 --- a/mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +++ b/mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py @@ -66,7 +66,7 @@ class CovidDisinformationNLMultiLabelClassification(AbsTaskMultilabelClassificat }, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: labels = [ "q2_label", "q3_label", diff --git a/mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py b/mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py index f02cb4db76..81c50e1390 100644 --- a/mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py +++ b/mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py @@ -60,7 +60,7 @@ class PubChemWikiPairClassification(AbsTaskPairClassification): """, ) - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: _dataset = {} for lang in self.hf_subsets: _dataset[lang] = {} diff --git a/mteb/tasks/retrieval/code/code_rag.py b/mteb/tasks/retrieval/code/code_rag.py index 6f2b1eb6b5..15000cf02a 100644 --- a/mteb/tasks/retrieval/code/code_rag.py +++ b/mteb/tasks/retrieval/code/code_rag.py @@ -59,7 +59,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text @@ -116,7 +116,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text @@ -176,7 +176,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text @@ -233,7 +233,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text diff --git a/mteb/tasks/retrieval/dan/dan_fever_retrieval.py b/mteb/tasks/retrieval/dan/dan_fever_retrieval.py index 7095901908..77155ea484 100644 --- a/mteb/tasks/retrieval/dan/dan_fever_retrieval.py +++ b/mteb/tasks/retrieval/dan/dan_fever_retrieval.py @@ -55,7 +55,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text diff --git a/mteb/tasks/retrieval/dan/tv2_nordretrieval.py b/mteb/tasks/retrieval/dan/tv2_nordretrieval.py index 80292c04d2..07cb6f5fa9 100644 --- a/mteb/tasks/retrieval/dan/tv2_nordretrieval.py +++ b/mteb/tasks/retrieval/dan/tv2_nordretrieval.py @@ -68,7 +68,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text diff --git a/mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py b/mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py index cfe2266d1e..2449a77c96 100644 --- a/mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +++ b/mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py @@ -44,7 +44,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text diff --git a/mteb/tasks/retrieval/nob/norquad.py b/mteb/tasks/retrieval/nob/norquad.py index 6f439993f5..af55744fdd 100644 --- a/mteb/tasks/retrieval/nob/norquad.py +++ b/mteb/tasks/retrieval/nob/norquad.py @@ -58,7 +58,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text diff --git a/mteb/tasks/retrieval/nob/snl_retrieval.py b/mteb/tasks/retrieval/nob/snl_retrieval.py index 1f13690838..16b49673bc 100644 --- a/mteb/tasks/retrieval/nob/snl_retrieval.py +++ b/mteb/tasks/retrieval/nob/snl_retrieval.py @@ -45,7 +45,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None: self.dataset_transform() self.data_loaded = True - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: """And transform to a retrieval dataset, which have the following attributes self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text diff --git a/mteb/tasks/sts/multilingual/sem_rel24_sts.py b/mteb/tasks/sts/multilingual/sem_rel24_sts.py index e4f567964b..7b790f408e 100644 --- a/mteb/tasks/sts/multilingual/sem_rel24_sts.py +++ b/mteb/tasks/sts/multilingual/sem_rel24_sts.py @@ -66,6 +66,6 @@ class SemRel24STS(AbsTaskSTS): min_score = 0 max_score = 1 - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: for lang, subset in self.dataset.items(): self.dataset[lang] = subset.rename_column("label", "score") diff --git a/mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py b/mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py index 9a85787c2d..a3392dc417 100644 --- a/mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py +++ b/mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py @@ -56,6 +56,6 @@ class STSBenchmarkMultilingualSTS(AbsTaskSTS): min_score = 0 max_score = 5 - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: for lang, subset in self.dataset.items(): self.dataset[lang] = subset.rename_column("similarity_score", "score") diff --git a/mteb/tasks/sts/por/assin2_sts.py b/mteb/tasks/sts/por/assin2_sts.py index b0b229fccd..4b4cd9f9b3 100644 --- a/mteb/tasks/sts/por/assin2_sts.py +++ b/mteb/tasks/sts/por/assin2_sts.py @@ -39,7 +39,7 @@ class Assin2STS(AbsTaskSTS): min_score = 1 max_score = 5 - def dataset_transform(self) -> None: + def dataset_transform(self, num_proc: int = 1, **kwargs) -> None: self.dataset = self.dataset.rename_columns( { "premise": "sentence1",