mteb/abstasks/pair_classification.py (20 changes: 13 additions & 7 deletions)

@@ -25,6 +25,8 @@
 if TYPE_CHECKING:
     from pathlib import Path
 
+    from numpy.typing import NDArray
+
     from mteb._evaluators.pair_classification_evaluator import (
         PairClassificationDistances,
     )
@@ -36,7 +38,6 @@
         TextStatistics,
     )
 
-
 logger = logging.getLogger(__name__)
@@ -138,7 +139,7 @@ def _compute_metrics(
         self, similarity_scores: PairClassificationDistances, labels: list[int]
     ) -> dict[str, float]:
         logger.info("Computing metrics...")
-        np_labels = np.asarray(labels)
+        np_labels: NDArray[np.int64] = np.asarray(labels, dtype=np.int64)
         output_scores = {}
         max_scores = defaultdict(list)
         for short_name, scores, reverse in [
@@ -281,7 +282,10 @@ def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
         )
 
     def _compute_metrics_values(
-        self, scores: list[float], labels: np.ndarray, high_score_more_similar: bool
+        self,
+        scores: list[float],
+        labels: NDArray[np.int64],
+        high_score_more_similar: bool,
     ) -> dict[str, float]:
         """Compute the metrics for the given scores and labels.
@@ -315,15 +319,18 @@ def _compute_metrics_values(
         )
 
     def _find_best_acc_and_threshold(
-        self, scores: list[float], labels: np.ndarray, high_score_more_similar: bool
+        self,
+        scores: list[float],
+        labels: NDArray[np.int64],
+        high_score_more_similar: bool,
    ) -> tuple[float, float]:
         rows = list(zip(scores, labels))
         rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)
 
         max_acc = 0
         best_threshold = -1.0
         positive_so_far = 0
-        remaining_negatives = sum(np.array(labels) == 0)
+        remaining_negatives = sum(labels == 0)
 
         for i in range(len(rows) - 1):
             score, label = rows[i]
@@ -339,10 +346,9 @@ def _find_best_acc_and_threshold(
         return max_acc, best_threshold
 
     def _find_best_f1_and_threshold(
-        self, scores, labels, high_score_more_similar: bool
+        self, scores, labels: NDArray[np.int64], high_score_more_similar: bool
     ) -> tuple[float, float, float, float]:
         scores = np.asarray(scores)
-        labels = np.asarray(labels)
 
         rows = list(zip(scores, labels))
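
Taken together, the changes above convert labels to a typed int64 array once (np_labels in _compute_metrics) and annotate the helper signatures with NDArray[np.int64], which is why the redundant np.array(labels) and np.asarray(labels) re-wraps further down can be dropped. A minimal sketch of the pattern; the count_negatives helper and sample labels are illustrative, not from the PR:

    import numpy as np
    from numpy.typing import NDArray

    def count_negatives(labels: NDArray[np.int64]) -> int:
        # The annotation promises an int64 ndarray, so the comparison is
        # vectorized directly; no defensive np.asarray() re-wrap is needed.
        return int((labels == 0).sum())

    np_labels: NDArray[np.int64] = np.asarray([1, 0, 1, 0, 0], dtype=np.int64)
    print(count_negatives(np_labels))  # 3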
mteb/tasks/pair_classification/fas/fars_tail.py (36 changes: 2 additions & 34 deletions)

@@ -1,5 +1,3 @@
-import datasets
-
 from mteb.abstasks.pair_classification import AbsTaskPairClassification
 from mteb.abstasks.task_metadata import TaskMetadata
 
@@ -8,8 +6,8 @@ class FarsTail(AbsTaskPairClassification):
     metadata = TaskMetadata(
         name="FarsTail",
         dataset={
-            "path": "azarijafari/FarsTail",
-            "revision": "7335288588f14e5a687d97fc979194c2abe6f4e7",
+            "path": "mteb/FarsTail",
+            "revision": "0fa0863dc160869b5a2d78803b4440ea3c671ff5",
         },
         description="This dataset, named FarsTail, includes 10,367 samples which are provided in both the Persian language as well as the indexed format to be useful for non-Persian researchers. The samples are generated from 3,539 multiple-choice questions with the least amount of annotator interventions in a way similar to the SciTail dataset",
         reference="https://link.springer.com/article/10.1007/s00500-023-08959-3",
@@ -37,33 +35,3 @@ class FarsTail(AbsTaskPairClassification):
         }
         """, # after removing neutral
     )
-
-    def load_data(self, num_proc: int = 1, **kwargs) -> None:
-        if self.data_loaded:
-            return
-        path = self.metadata.dataset["path"]
-        revision = self.metadata.dataset["revision"]
-        data_files = {
-            "test": f"https://huggingface.co/datasets/{path}/resolve/{revision}/data/Test-word.csv"
-        }
-        self.dataset = datasets.load_dataset(
-            "csv", data_files=data_files, delimiter="\t"
-        )
-        self.dataset_transform()
-        self.data_loaded = True
-
-    def dataset_transform(self, num_proc: int = 1):
-        _dataset = {}
-        self.dataset = self.dataset.filter(lambda x: x["label"] != "n")
-        self.dataset = self.dataset.map(
-            lambda example: {"label": 1 if example["label"] == "e" else 0}
-        )
-        for split in self.metadata.eval_splits:
-            _dataset[split] = [
-                {
-                    "sentence1": self.dataset[split]["premise"],
-                    "sentence2": self.dataset[split]["hypothesis"],
-                    "labels": self.dataset[split]["label"],
-                }
-            ]
-        self.dataset = _dataset
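
The custom load_data/dataset_transform pair becomes unnecessary because the task now points at the pre-converted mteb/FarsTail dataset, which the base class's default loader consumes directly. For reference, a hedged replay of what the removed code did (drop neutral pairs, binarize entailment), assuming a local copy of the original tab-separated Test-word.csv:

    import datasets

    # Stand-alone sketch of the removed preprocessing; the file path is
    # illustrative and stands in for the original azarijafari/FarsTail file.
    ds = datasets.load_dataset(
        "csv", data_files={"test": "Test-word.csv"}, delimiter="\t"
    )["test"]
    ds = ds.filter(lambda x: x["label"] != "n")  # remove neutral pairs
    ds = ds.map(lambda x: {"label": 1 if x["label"] == "e" else 0})  # e -> 1, else 0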
pyproject.toml (1 change: 1 addition & 0 deletions)

@@ -428,6 +428,7 @@ extend-exclude = [
     "docs/references.bib",
     "mteb/models/model_implementations/gme_v_models.py", # video_grid_thw `thw`
     "mteb/models/model_implementations/vista_models.py", # self.normlized: in visual bge
+    "mteb/models/model_implementations/salesforce_models.py", # multiligual in paper title
     "tests/mock_tasks.py", # "denne her matche ikke den ovenstående",
     "mteb/models/model_implementations/kalm_models.py", # prompt: classify ist topic",
     "mteb/tasks/reranking/eng/built_bench_reranking.py", # prompt: descriptions from buit asset
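
Judging by the neighboring comments, this extend-exclude block is the repository's spell-check ignore list: each entry names a file that legitimately contains a misspelled token, here "multiligual" quoted verbatim from a paper title, so the checker must skip the file rather than flag it.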
tests/test_search_index/test_search_index.py (1 change: 1 addition & 0 deletions)

@@ -30,6 +30,7 @@ def test_retrieval_backends(
     task: AbsTaskRetrieval, similarity: ScoringFunction, tmp_path: Path
 ):
     """Test different retrieval backends for retrieval and reranking tasks."""
+    pytest.importorskip("faiss", reason="faiss is not installed")
     model = mteb.get_model("baseline/random-encoder-baseline")
     model_meta = deepcopy(model.mteb_model_meta)
     model_meta.similarity_fn_name = similarity
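
pytest.importorskip imports and returns the named module when it is available and skips the calling test otherwise, so environments without the optional faiss backend now skip this test cleanly instead of failing with an ImportError. A minimal self-contained sketch of the same pattern; the test body below is hypothetical, not from this PR:

    import numpy as np
    import pytest

    def test_faiss_flat_index_self_lookup():
        # Skip (rather than fail) when the optional dependency is missing.
        faiss = pytest.importorskip("faiss", reason="faiss is not installed")

        vectors = np.random.rand(8, 4).astype("float32")
        index = faiss.IndexFlatL2(4)  # exact L2 search over 4-dim vectors
        index.add(vectors)
        _, ids = index.search(vectors[:1], 1)
        assert ids[0][0] == 0  # a stored vector's nearest neighbor is itself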