Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mteb/abstasks/abstask.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,15 @@ def check_if_dataset_is_superseded(self) -> None:
logger.warning(msg)
warnings.warn(msg)

def dataset_transform(self, num_proc: int = 1):
def dataset_transform(self, num_proc: int = 1, **kwargs: Any) -> None:
"""A transform operation applied to the dataset after loading.

This method is useful when the dataset from Huggingface is not in an `mteb` compatible format.
Override this method if your dataset requires additional transformation.

Args:
num_proc: Number of processes to use for the transformation.
kwargs: Additional keyword arguments passed to the load_dataset function. Keep for forward compatibility.
"""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def load_data(self, **kwargs: Any) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
# Convert to standard format
for lang in self.hf_subsets:
l1, l2 = (l.split("_")[0] for l in lang.split("-"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class BengaliDocumentClassification(AbsTaskClassification):
superseded_by="BengaliDocumentClassification.v2",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"article": "text", "category": "label"}
)
Expand Down Expand Up @@ -92,7 +92,7 @@ class BengaliDocumentClassificationV2(AbsTaskClassification):
""",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.stratified_subsampling(
self.dataset, seed=self.seed, splits=["test"]
)
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CzechProductReviewSentimentClassification(AbsTaskClassification):
)
samples_per_label = 16

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"comment": "text", "rating_str": "label"}
)
Expand Down Expand Up @@ -99,7 +99,7 @@ class CzechProductReviewSentimentClassificationV2(AbsTaskClassification):
)
samples_per_label = 16

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.stratified_subsampling(
self.dataset, seed=self.seed, splits=["test"]
)
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CzechSoMeSentimentClassification(AbsTaskClassification):
)
samples_per_label = 16

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"comment": "text", "sentiment_int": "label"}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class HinDialectClassification(AbsTaskClassification):
""",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"folksong": "text", "language": "label"}
)
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ def load_data(self, **kwargs: Any) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.remove_columns(["language", "script"])
self.dataset = self.dataset.rename_columns({"native sentence": "text"})
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class IndicSentimentClassification(AbsTaskClassification):
""",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
label_map = {"Negative": 0, "Positive": 1}
# Convert to standard format
for lang in self.hf_subsets:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class LanguageClassification(AbsTaskClassification):
""",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns({"labels": "label"})
self.dataset = self.stratified_subsampling(
self.dataset, seed=self.seed, splits=["test"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SouthAfricanLangClassification(AbsTaskClassification):
""",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{" text": "text", "lang_id": "label"}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SlovakMovieReviewSentimentClassification(AbsTaskClassification):
superseded_by="SlovakMovieReviewSentimentClassification.v2",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns({"comment": "text"})

self.dataset = self.stratified_subsampling(
Expand Down Expand Up @@ -76,7 +76,7 @@ class SlovakMovieReviewSentimentClassificationV2(AbsTaskClassification):
adapted_from=["SlovakMovieReviewSentimentClassification"],
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.stratified_subsampling(
self.dataset, seed=self.seed, splits=["test"]
)
4 changes: 2 additions & 2 deletions mteb/tasks/classification/swa/swahili_news_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class SwahiliNewsClassification(AbsTaskClassification):
superseded_by="SwahiliNewsClassification.v2",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"content": "text", "category": "label"}
)
Expand Down Expand Up @@ -81,7 +81,7 @@ class SwahiliNewsClassificationV2(AbsTaskClassification):
adapted_from=["SwahiliNewsClassification"],
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.stratified_subsampling(
self.dataset, seed=self.seed, splits=["train"]
)
2 changes: 1 addition & 1 deletion mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TenKGnadClusteringP2PFast(AbsTaskClustering):
adapted_from=["TenKGnadClusteringP2P"],
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
ds = _convert_to_fast(
self.dataset, self.input_column_name, self.label_column_name, self.seed
)
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TenKGnadClusteringS2SFast(AbsTaskClustering):
adapted_from=["TenKGnadClusteringS2S"],
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
ds = _convert_to_fast(
self.dataset, self.input_column_name, self.label_column_name, self.seed
)
Expand Down
4 changes: 2 additions & 2 deletions mteb/tasks/clustering/nob/vg_hierarchical_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class VGHierarchicalClusteringP2P(AbsTaskClustering):
prompt="Identify the categories (e.g. sports) of given articles in Norwegian",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"article": "sentences", "classes": "labels"}
)
Expand Down Expand Up @@ -92,7 +92,7 @@ class VGHierarchicalClusteringS2S(AbsTaskClustering):
prompt="Identify the categories (e.g. sports) of given articles in Norwegian",
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{"ingress": "sentences", "classes": "labels"}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class CovidDisinformationNLMultiLabelClassification(AbsTaskMultilabelClassificat
},
)

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
labels = [
"q2_label",
"q3_label",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PubChemWikiPairClassification(AbsTaskPairClassification):
""",
)

def dataset_transform(self, num_proc: int = 1) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
_dataset = {}
for lang in self.hf_subsets:
_dataset[lang] = {}
Expand Down
8 changes: 4 additions & 4 deletions mteb/tasks/retrieval/code/code_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text
Expand Down Expand Up @@ -116,7 +116,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text
Expand Down Expand Up @@ -176,7 +176,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text
Expand Down Expand Up @@ -233,7 +233,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = Dict[doc_id, Dict[str, str]] #id => dict with document data like title and text
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/retrieval/dan/dan_fever_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/retrieval/dan/tv2_nordretrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/retrieval/nob/norquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/retrieval/nob/snl_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load_data(self, num_proc: int = 1, **kwargs) -> None:
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
"""Transform to a retrieval dataset, which has the following attributes

self.corpus = dict[doc_id, dict[str, str]] #id => dict with document data like title and text
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/sts/multilingual/sem_rel24_sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ class SemRel24STS(AbsTaskSTS):
min_score = 0
max_score = 1

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
for lang, subset in self.dataset.items():
self.dataset[lang] = subset.rename_column("label", "score")
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ class STSBenchmarkMultilingualSTS(AbsTaskSTS):
min_score = 0
max_score = 5

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
for lang, subset in self.dataset.items():
self.dataset[lang] = subset.rename_column("similarity_score", "score")
2 changes: 1 addition & 1 deletion mteb/tasks/sts/por/assin2_sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Assin2STS(AbsTaskSTS):
min_score = 1
max_score = 5

def dataset_transform(self) -> None:
def dataset_transform(self, num_proc: int = 1, **kwargs) -> None:
self.dataset = self.dataset.rename_columns(
{
"premise": "sentence1",
Expand Down