Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a09734f
add pytyped
Samoed Nov 11, 2025
98eab29
start typing
Samoed Nov 11, 2025
e028aea
finish evaluators
Samoed Nov 11, 2025
86e7efd
add more types
Samoed Nov 11, 2025
84ab864
Update mteb/results/benchmark_results.py
Samoed Nov 12, 2025
5d2afaf
Merge branch 'main' into typing
Samoed Dec 8, 2025
21ff289
apply comments
Samoed Dec 8, 2025
f9d3035
continue typechecking
Samoed Dec 8, 2025
b70886c
Merge branch 'main' into typing
Samoed Dec 15, 2025
c9dedc7
fix typehint
Samoed Dec 15, 2025
222658c
Merge branch 'main' into typing
Samoed Dec 15, 2025
fed4ecc
typechecking
Samoed Dec 15, 2025
5cfc64b
fix tests
Samoed Dec 15, 2025
0c45374
fix type errors again
Samoed Dec 15, 2025
234fdac
fix cache
Samoed Dec 16, 2025
59a65ef
Merge branch 'main' into typing
Samoed Dec 21, 2025
39e09dd
add more types
Samoed Dec 22, 2025
20fa646
fix method
Samoed Dec 22, 2025
8d4daa5
Merge branch 'main' into typing
Samoed Dec 22, 2025
d396270
roll back pyproject
Samoed Dec 22, 2025
794a32f
activate PGH
Samoed Dec 22, 2025
aed114d
install more types
Samoed Dec 22, 2025
fac5c58
almost finish
Samoed Dec 22, 2025
651c0e0
fix search wrappers
Samoed Dec 22, 2025
d0c061f
add ci
Samoed Dec 22, 2025
9846e66
fix tests
Samoed Dec 22, 2025
1858d7e
fix 3.10 types
Samoed Dec 22, 2025
39cbc21
rollback overload
Samoed Dec 22, 2025
d972ae5
Merge branch 'main' into typing
Samoed Dec 26, 2025
3b3f798
fixes after merge
Samoed Dec 26, 2025
93d2230
change to iterable
Samoed Dec 26, 2025
9b3c1d4
add fixes
Samoed Dec 26, 2025
3d8c073
remove summarization scores hint
Samoed Dec 26, 2025
ed773c0
simplify deprecated_evaluator
Samoed Dec 26, 2025
db47e14
simplify model conversion
Samoed Dec 26, 2025
00bac9c
add comment for typechecking
Samoed Dec 26, 2025
cb8cf8e
remove casts
Samoed Dec 26, 2025
1c55982
Merge branch 'main' into typing
Samoed Dec 27, 2025
f33c354
remove duplicated function
Samoed Dec 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/typechecking.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# CI workflow: run the project's static type checker on every push to
# main and on every pull request.
name: Typechecking

on:
  push:
    branches: [main]
  pull_request:

jobs:
  typecheck:
    runs-on: ubuntu-latest
    steps:
      # The test/typing environment can exhaust the default runner's disk,
      # so remove large preinstalled toolchains first.
      # Fixed: /usr/local/share/powershell was listed twice in this list.
      - name: Free disk space
        run: |
          sudo rm -rf \
            "$AGENT_TOOLSDIRECTORY" \
            /opt/ghc \
            /opt/google/chrome \
            /opt/microsoft/msedge \
            /opt/microsoft/powershell \
            /opt/pipx \
            /usr/lib/mono \
            /usr/local/julia* \
            /usr/local/lib/android \
            /usr/local/lib/node_modules \
            /usr/local/share/chromium \
            /usr/local/share/powershell \
            /usr/share/dotnet \
            /usr/share/swift
          docker system prune -af

      - uses: actions/checkout@v6
      - uses: actions/setup-python@v6
        with:
          # Check against the oldest supported interpreter so syntax that
          # only works on newer Pythons is caught here.
          python-version: "3.10"

      - name: Dependencies
        run: |
          make install-for-tests
          pip install -e . --group typing

      # Fixed: this step was mislabelled "Build and Deploy"; it runs the
      # type checker, nothing is built or deployed.
      - name: Typecheck
        run: |
          make typecheck
19 changes: 9 additions & 10 deletions docs/mmteb/validate_points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from pathlib import Path
from typing import Optional

from jsonlines import Reader
from pydantic import BaseModel, ConfigDict, Field, ValidationError, conint, constr
Expand All @@ -21,17 +20,17 @@
class JsonObject(BaseModel):
model_config = ConfigDict(extra="forbid")
GitHub: constr(min_length=1)
new_dataset: Optional[conint(ge=1)] = Field(alias="New dataset", default=None) # noqa
new_task: Optional[conint(ge=2)] = Field(alias="New task", default=None) # noqa
dataset_annotations: Optional[conint(ge=1)] = Field( # noqa
new_dataset: conint(ge=1) | None = Field(alias="New dataset", default=None)
new_task: conint(ge=2) | None = Field(alias="New task", default=None)
dataset_annotations: conint(ge=1) | None = Field(
alias="Dataset annotations", default=None
)
bug_fixes: Optional[conint(ge=1)] = Field(alias="Bug fixes", default=None) # noqa
running_models: Optional[conint(ge=1)] = Field(alias="Running Models", default=None) # noqa
review_pr: Optional[conint(ge=2)] = Field(alias="Review PR", default=None) # noqa
paper_writing: Optional[int] = Field(alias="Paper writing", default=None) # noqa
Ideation: Optional[int] = None # noqa
Coordination: Optional[int] = None # noqa
bug_fixes: conint(ge=1) | None = Field(alias="Bug fixes", default=None)
running_models: conint(ge=1) | None = Field(alias="Running Models", default=None)
review_pr: conint(ge=2) | None = Field(alias="Review PR", default=None)
paper_writing: int | None = Field(alias="Paper writing", default=None)
Ideation: int | None = None
Coordination: int | None = None


def check_max_points(obj: JsonObject, commit_n: str):
Expand Down
25 changes: 10 additions & 15 deletions mteb/_create_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from collections.abc import Callable
from typing import Any, cast

Expand Down Expand Up @@ -113,11 +114,8 @@ def _create_text_dataloader_for_queries(
)


_warned_about_user_role = False


def _convert_conv_history_to_query(
row: dict[str, list[str] | Conversation],
row: dict[str, str | list[str] | Conversation],
) -> dict[str, str | Conversation]:
"""Convert a conversation history to a single query string.

Expand All @@ -127,21 +125,18 @@ def _convert_conv_history_to_query(
Returns:
The updated row with the "query" and "text" fields set to the conversation string, and the "conversation" field set to the list of ConversationTurn.
"""
global _warned_about_user_role

conversation = row["text"]
# if it's a list of strings, just join them
if isinstance(conversation, list) and isinstance(conversation[0], str):
conversation = cast(list[str], conversation)
conv_str = "; ".join(conversation)
conversation_ = cast(list[str], conversation)
conv_str = "; ".join(conversation_)
current_conversation = [
ConversationTurn(role="user", content=message) for message in conversation
ConversationTurn(role="user", content=message) for message in conversation_
]
if not _warned_about_user_role:
logger.warning(
"Conversations are a list of strings. Used 'user' role for all turns."
)
_warned_about_user_role = True
warnings.warn(
"Conversations are a list of strings. Used 'user' role for all turns.",
category=UserWarning,
)
# otherwise, it's a list of dictionaries, which we need to convert to strings
elif isinstance(conversation, list) and isinstance(conversation[0], dict):
conv = []
Expand Down Expand Up @@ -178,7 +173,7 @@ def _convert_conv_history_to_query(

row["text"] = conv_str
row["conversation"] = current_conversation
return row
return cast(dict[str, str | list[ConversationTurn]], row)


def _create_dataloader_for_queries_conversation(
Expand Down
5 changes: 1 addition & 4 deletions mteb/_evaluators/any_sts_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def __init__(
self.input2_prompt_type = input2_prompt_type

def __call__(
self,
model: EncoderProtocol,
*,
encode_kwargs: dict[str, Any],
self, model: EncoderProtocol, *, encode_kwargs: dict[str, Any]
) -> STSEvaluatorScores:
logger.info("Running semantic similarity - Encoding samples (1/2)")
embeddings1 = model.encode(
Expand Down
3 changes: 2 additions & 1 deletion mteb/_evaluators/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from typing import Any

from mteb.abstasks.abstask import _set_seed
Expand All @@ -18,7 +19,7 @@ def __init__(self, seed: int = 42, **kwargs: Any) -> None:
@abstractmethod
def __call__(
self, model: EncoderProtocol, *, encode_kwargs: dict[str, Any]
) -> dict[str, float]:
) -> Mapping[str, float] | Iterable[Any]:
"""This is called during training to evaluate the model.

It returns scores.
Expand Down
11 changes: 5 additions & 6 deletions mteb/_evaluators/image/imagetext_pairclassification_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import torch
Expand Down Expand Up @@ -61,8 +62,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
def __init__(
self,
dataset,
images_column_names: str | list[str],
texts_column_names: str | list[str],
images_column_names: str | Sequence[str],
texts_column_names: str | Sequence[str],
num_images_per_sample: int,
num_texts_per_sample: int,
task_metadata: TaskMetadata,
Expand All @@ -82,10 +83,8 @@ def __init__(
self.hf_split = hf_split
self.hf_subset = hf_subset

def __call__(
self,
model: EncoderProtocol,
encode_kwargs: dict[str, Any],
def __call__( # type: ignore[override]
self, model: EncoderProtocol, *, encode_kwargs: dict[str, Any]
) -> list[torch.Tensor]:
images = []
if isinstance(self.images_column_names, str):
Expand Down
4 changes: 3 additions & 1 deletion mteb/_evaluators/pair_classification_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def _encode_unique_texts(
hf_subset: str,
**encode_kwargs: Any,
) -> np.ndarray:
index_map, all_unique_texts, all_texts_indexes = {}, [], []
index_map = {}
all_unique_texts: list[str] = []
all_texts_indexes = []
for text in all_texts:
text_hash = hash(text)
if text_hash not in index_map:
Expand Down
33 changes: 17 additions & 16 deletions mteb/_evaluators/retrieval_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any

import numpy as np
Expand All @@ -15,7 +16,7 @@

def mrr(
qrels: RelevantDocumentsType,
results: dict[str, dict[str, float]],
results: Mapping[str, Mapping[str, float]],
k_values: list[int],
) -> dict[str, list[float]]:
mrr_metrics = defaultdict(list)
Expand All @@ -32,7 +33,7 @@ def mrr(
doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
}
for k in k_values:
rr = 0
rr = 0.0
for rank, hit in enumerate(top_hits[query_id][0:k]):
if hit[0] in query_relevant_docs:
rr = 1.0 / (rank + 1)
Expand All @@ -45,8 +46,8 @@ def recall_cap(
qrels: RelevantDocumentsType,
results: dict[str, dict[str, float]],
k_values: list[int],
) -> dict[str, list[float]]:
capped_recall = defaultdict(list)
) -> dict[str, list[float | None]]:
capped_recall: dict[str, list[float | None]] = defaultdict(list)

k_max = max(k_values)

Expand Down Expand Up @@ -188,7 +189,7 @@ def evaluate_p_mrr_change(
Returns:
A dictionary with the scores, including "p-MRR", "og" and "changed" keys.
"""
followir_scores = defaultdict(dict)
followir_scores: dict[str, float | dict[str, float]] = defaultdict(dict)

qrels_sep = {
"og": {k: v for k, v in qrels.items() if k.endswith("-og")},
Expand Down Expand Up @@ -227,7 +228,7 @@ def evaluate_p_mrr_change(
ndcg, _map, recall, precision, naucs, avg_mrr, naucs_mrr, cv_recall, {}
)
for key, value in scores_dict.items():
followir_scores[name][key] = value
followir_scores[name][key] = value # type: ignore[index]

return followir_scores

Expand All @@ -254,8 +255,8 @@ def confidence_scores(sim_scores: list[float]) -> dict[str, float]:
sim_scores_sorted = sorted(sim_scores)[::-1]

cs_max = sim_scores_sorted[0]
cs_std = np.std(sim_scores)
cs_diff1 = None
cs_std = float(np.std(sim_scores))
cs_diff1 = 0.0
if len(sim_scores) > 1:
cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
elif len(sim_scores) == 1:
Expand Down Expand Up @@ -410,7 +411,7 @@ def make_score_dict(
cv_recall: dict[str, float],
task_scores: dict[str, float],
previous_results_model_meta: dict[str, Any] | None = None,
) -> dict[str, float]:
) -> dict[str, Any]:
return {
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
Expand Down Expand Up @@ -528,7 +529,7 @@ def max_over_subqueries(


def calculate_retrieval_scores(
results: dict[str, dict[str, float]],
results: Mapping[str, Mapping[str, float]],
qrels: RelevantDocumentsType,
k_values: list[int],
skip_first_result: bool = False,
Expand Down Expand Up @@ -576,7 +577,7 @@ def calculate_retrieval_scores(


def evaluate_abstention(
results: dict[str, dict[str, float]],
results: Mapping[str, Mapping[str, float]],
metric_scores: dict[str, list[float]],
) -> dict[str, float]:
"""Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
Expand All @@ -591,21 +592,21 @@ def evaluate_abstention(
all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())]
all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores]
conf_fcts = list(all_conf_scores[0].keys())
all_conf_scores = {
all_conf_scores_ = {
fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts
}
metric_scores = {k: np.array(v) for k, v in metric_scores.items()}
metric_scores_ = {k: np.array(v) for k, v in metric_scores.items()}
naucs = {}

for metric_name, scores in metric_scores.items():
for fct, conf_scores in all_conf_scores.items():
for metric_name, scores in metric_scores_.items():
for fct, conf_scores in all_conf_scores_.items():
naucs[f"nAUC_{metric_name}_{fct}"] = nauc(conf_scores, scores)

return naucs


def calculate_cv_recall(
results: dict[str, dict[str, float]],
results: Mapping[str, Mapping[str, float]],
qrels: RelevantDocumentsType,
k_values: list[int],
skip_first_result: bool = False,
Expand Down
17 changes: 9 additions & 8 deletions mteb/_evaluators/sklearn_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Protocol
from typing import Any, Protocol, cast

import numpy as np
from datasets import Dataset
Expand All @@ -9,19 +9,19 @@
from mteb._create_dataloaders import create_dataloader
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import EncoderProtocol
from mteb.types import BatchedInput
from mteb.types import Array, BatchedInput

from .evaluator import Evaluator

logger = logging.getLogger(__name__)


class SklearnModelProtocol(Protocol):
def fit(self, X: np.ndarray, y: np.ndarray | list[int]) -> None: ... # noqa: N803
def predict(self, X: np.ndarray) -> np.ndarray: ... # noqa: N803
def fit(self, X: Array, y: np.ndarray | list[int]) -> None: ... # noqa: N803
def predict(self, X: Array) -> np.ndarray: ... # noqa: N803
def get_params(self) -> dict[str, Any]: ...
def set_params(self, **kwargs: dict[str, Any]) -> Self: ...
def score(self, X: np.ndarray, y: np.ndarray | list[int]) -> float: ... # noqa: N803
def set_params(self, random_state: int, **kwargs: dict[str, Any]) -> Self: ...
def score(self, X: Array, y: np.ndarray | list[int]) -> float: ... # noqa: N803


class SklearnEvaluator(Evaluator):
Expand Down Expand Up @@ -71,8 +71,8 @@ def __call__( # type: ignore[override]
model: EncoderProtocol,
*,
encode_kwargs: dict[str, Any],
test_cache: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
test_cache: Array | None = None,
) -> tuple[np.ndarray, Array]:
"""Classification evaluation by training a sklearn classifier on the embeddings of the training set and evaluating on the embeddings of the test set.

Args:
Expand Down Expand Up @@ -104,6 +104,7 @@ def __call__( # type: ignore[override]
hf_subset=self.hf_subset,
**encode_kwargs,
)
test_cache = cast(Array, test_cache)

logger.info("Running - Fitting classifier...")
y_train = self.train_dataset[self.label_column_name]
Expand Down
Loading