diff --git a/mteb/abstasks/_data_filter/__init__.py b/mteb/abstasks/_data_filter/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/mteb/abstasks/_data_filter/filters.py b/mteb/abstasks/_data_filter/filters.py
new file mode 100644
index 0000000000..23f12cd820
--- /dev/null
+++ b/mteb/abstasks/_data_filter/filters.py
@@ -0,0 +1,125 @@
+"""Simplified version of https://gist.github.com/AlexeyVatolin/ea3adc21aa7a767603ff393b22085adc from https://github.com/embeddings-benchmark/mteb/pull/2900"""
+
+import logging
+
+import datasets
+import pandas as pd
+from datasets import Dataset, DatasetDict
+
+from mteb import TaskMetadata
+
+logger = logging.getLogger(__name__)
+
+
+def deduplicate(dataset: Dataset, input_column: str) -> Dataset:
+    """Remove duplicate texts, keeping the first occurrence."""
+    unique_texts = set()
+    indices_to_keep = []
+    for i, text in enumerate(dataset[input_column]):
+        text = text.strip()
+        if text not in unique_texts:
+            unique_texts.add(text)
+            indices_to_keep.append(i)
+
+    logger.info(
+        f"[deduplicate] removed={len(dataset) - len(indices_to_keep)}/{len(dataset)}"
+    )
+    return dataset.select(indices_to_keep)
+
+
+def filter_empty(dataset: Dataset, input_column: str) -> Dataset:
+    """Filter out empty or whitespace-only examples."""
+    before = len(dataset)
+    ds = dataset.filter(lambda x: len(x[input_column].strip()) > 0)
+    logger.info(f"[filter_empty] removed={before - len(ds)}/{before}")
+    return ds
+
+
+def filter_train_leakage(
+    train_dataset: Dataset, test_dataset: Dataset, input_column: str
+) -> Dataset:
+    """Remove test examples that also appear in the training split."""
+    train_texts = set(train_dataset[input_column])
+    before = len(test_dataset)
+    indices = [
+        i
+        for i, text in enumerate(test_dataset[input_column])
+        if text not in train_texts
+    ]
+    logger.info(f"[filter_train_leakage] removed={before - len(indices)}/{before}")
+    return test_dataset.select(indices)
+
+
+def filter_unclear_label(
+    dataset_dict: DatasetDict, input_column: str, label_column: str
+) -> DatasetDict:
+    """Remove examples where the same text appears with multiple different labels."""
+    normalized: dict[str, set[str | int | float | tuple]] = {}
+    logger.debug("[filter_unclear_label] scanning dataset for label conflicts...")
+
+    for split, ds in dataset_dict.items():
+        for text, label in zip(ds[input_column], ds[label_column]):
+            key = text.strip().lower()
+            normalized.setdefault(key, set()).add(
+                label if isinstance(label, (str, int, float)) else tuple(label)
+            )
+
+    bad_texts = {t for t, labels in normalized.items() if len(labels) > 1}
+    logger.info(f"[filter_unclear_label] Removing {len(bad_texts)} conflicting texts")
+
+    new_dict = {}
+    for split, ds in dataset_dict.items():
+        before = len(ds)
+        filtered = ds.filter(lambda x: x[input_column].strip().lower() not in bad_texts)
+        logger.debug(
+            f"[filter_unclear_label:{split}] removed={before - len(filtered)}/{before}"
+        )
+        new_dict[split] = filtered
+
+    return DatasetDict(new_dict)
+
+
+def filter_short(dataset: Dataset, input_column: str, min_words: int = 3) -> Dataset:
+    """Filter out texts with fewer than `min_words` words."""
+    before = len(dataset)
+    ds = dataset.filter(lambda x: len(x[input_column].strip().split()) >= min_words)
+    logger.debug(f"[filter_short] removed={before - len(ds)}/{before}")
+    return ds
+
+
+def split_train_test(
+    ds: DatasetDict,
+    metadata: TaskMetadata,
+    train_split: str,
+    label_column: str,
+) -> DatasetDict:
+    if train_split in ds and metadata.eval_splits == [train_split]:
+        before = len(ds[train_split])
+        logger.info(
+            f"[split_train_test] eval_splits == train_split; performing split on {before} examples"
+        )
+        ds[train_split] = ds[train_split].cast_column(
+            label_column,
+            datasets.ClassLabel(names=list(set(ds[train_split][label_column]))),
+        )
+        label_counts = pd.Series(ds[train_split][label_column]).value_counts()
+        one_sample_labels = set(label_counts[label_counts == 1].index.tolist())
+
+        if one_sample_labels:
+            logger.info(
+                f"[split_train_test] Removing {len(one_sample_labels)} labels with only one instance"
+            )
+            ds[train_split] = ds[train_split].filter(
+                lambda x: x[label_column] not in one_sample_labels
+            )
+
+        splits = ds[train_split].train_test_split(
+            test_size=min(2048, before // 2), seed=42, stratify_by_column=label_column
+        )
+        ds = DatasetDict({train_split: splits["train"], "test": splits["test"]})
+        metadata.eval_splits = ["test"]
+        logger.info(
+            f"[split_train_test] Train size={len(ds[train_split])}, Test size={len(ds['test'])}"
+        )
+
+    return ds
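Reviewer note: a minimal sketch of how the standalone filters above compose on a toy dataset; the toy data and the print call are illustrative only and are not part of the diff.

# Sketch: applying the standalone filters from filters.py to a toy Dataset.
from datasets import Dataset

from mteb.abstasks._data_filter.filters import deduplicate, filter_empty, filter_short

toy = Dataset.from_dict(
    {
        "text": ["good value for money", "good value for money", "   ", "ok"],
        "label": [1, 1, 0, 0],
    }
)

toy = filter_empty(toy, input_column="text")               # drops the whitespace-only row
toy = deduplicate(toy, input_column="text")                # keeps only the first duplicate
toy = filter_short(toy, input_column="text", min_words=3)  # drops "ok" (fewer than 3 words)
print(toy["text"])  # ['good value for money']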
diff --git a/mteb/abstasks/_data_filter/task_pipelines.py b/mteb/abstasks/_data_filter/task_pipelines.py
new file mode 100644
index 0000000000..f12f10e60e
--- /dev/null
+++ b/mteb/abstasks/_data_filter/task_pipelines.py
@@ -0,0 +1,102 @@
+import logging
+
+from datasets import DatasetDict
+
+from mteb import TaskMetadata
+from mteb.abstasks import AbsTaskClassification
+from mteb.abstasks._data_filter.filters import (
+    deduplicate,
+    filter_empty,
+    filter_short,
+    filter_train_leakage,
+    filter_unclear_label,
+    split_train_test,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def clean_dataset(
+    ds: DatasetDict,
+    metadata: TaskMetadata,
+    train_split: str,
+    input_column: str,
+    label_column: str,
+    subset: str | None = None,
+) -> DatasetDict:
+    """Apply the full cleaning pipeline with logging."""
+    logger.info("[clean_dataset] Starting dataset cleaning pipeline...")
+
+    transforms = [
+        ("filter_empty", filter_empty),
+        ("deduplicate", deduplicate),
+    ]
+
+    # Scripts without whitespace word boundaries; word-count filtering is unreliable there.
+    skip_cjk_codes = {"zho", "jpn", "tha", "mya", "cmn"}
+    cur_langs = (
+        metadata.eval_langs[subset]
+        if isinstance(metadata.eval_langs, dict) and subset
+        else metadata.eval_langs
+    )
+    apply_short = not any(lang.split("-")[0] in skip_cjk_codes for lang in cur_langs)
+    if apply_short:
+        logger.info("[clean_dataset] Applying short-text filter")
+        transforms.append(("filter_short", filter_short))
+
+    for split in [train_split, *metadata.eval_splits]:
+        if split not in ds:
+            logger.warning(f"[clean_dataset] Split '{split}' missing; skipping.")
+            continue
+
+        for name, fn in transforms:
+            before = len(ds[split])
+            ds[split] = fn(ds[split], input_column=input_column)
+            logger.info(
+                f"[clean_dataset:{split}] {name} removed={before - len(ds[split])}"
+            )
+
+    ds = split_train_test(ds, metadata, train_split, label_column)
+
+    for split in metadata.eval_splits:
+        if split == train_split:
+            continue
+        before = len(ds[split])
+        ds[split] = filter_train_leakage(ds[train_split], ds[split], input_column)
+        logger.info(
+            f"[clean_dataset:{split}] leakage_removed={before - len(ds[split])}"
+        )
+
+    ds = filter_unclear_label(ds, input_column=input_column, label_column=label_column)
+
+    logger.info("[clean_dataset] Cleaning pipeline complete.")
+    return ds
+
+
+def process_classification(
+    task: AbsTaskClassification,
+) -> DatasetDict | dict[str, DatasetDict]:
+    """Process classification task dataset(s) with the cleaning pipeline."""
+    if not task.data_loaded:
+        task.load_data()
+    if isinstance(task.dataset, DatasetDict):
+        return clean_dataset(
+            task.dataset,
+            task.metadata,
+            task.train_split,
+            task.input_column_name,
+            task.label_column_name,
+            subset=None,
+        )
+
+    new_ds = {}
+    for subset in task.dataset:
+        new_ds[subset] = clean_dataset(
+            task.dataset[subset],
+            task.metadata,
+            task.train_split,
+            task.input_column_name,
+            task.label_column_name,
+            subset=subset,
+        )
+    return new_ds
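Reviewer note: a rough usage sketch for the new pipeline entry point; the task name and the mteb.get_task call are assumptions for illustration and are not part of the diff.

# Sketch: running the cleaning pipeline on one classification task.
import mteb

from mteb.abstasks._data_filter.task_pipelines import process_classification

task = mteb.get_task("Banking77Classification")  # example; any AbsTaskClassification should work
cleaned = process_classification(task)  # DatasetDict, or dict of DatasetDicts for multi-subset tasks
print({split: len(ds) for split, ds in cleaned.items()})  # per-split row counts (single-subset case)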