8 changes: 5 additions & 3 deletions docs/overview/create_available_benchmarks.py
@@ -1,13 +1,15 @@
"""Updates the available benchmarks markdown file."""

from pathlib import Path
from typing import cast
from typing import TYPE_CHECKING, cast

from prettify_list import pretty_long_list
from slugify import slugify_anchor

import mteb
from mteb.get_tasks import MTEBTasks

if TYPE_CHECKING:
    from mteb.get_tasks import MTEBTasks

benchmark_entry = """
#### {benchmark_name}
@@ -38,7 +40,7 @@
def create_table(benchmark: mteb.Benchmark) -> str:
"""Create a markdown table of tasks in the benchmark."""
tasks = benchmark.tasks
tasks = cast(MTEBTasks, tasks)
tasks = cast("MTEBTasks", tasks)
df = tasks.to_dataframe(["name", "type", "modalities", "languages"])

# add links to task names:
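The edit above, repeated with small variations in every file below, follows one recipe: imports needed only for type annotations move into an `if TYPE_CHECKING:` block so they are never executed at runtime, and `cast()` targets are quoted so the casted-to name needs no runtime binding (with `from __future__ import annotations` added where a file still lacks it, so annotations stay unevaluated strings). A minimal, self-contained sketch of that recipe follows; it is illustrative only, not code from the mteb repository, and the helper names `as_items` and `head` are made up for the example:

# Illustrative sketch only -- not code from the mteb repository.
from __future__ import annotations

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # Seen by static type checkers (mypy, pyright) but never imported at runtime.
    from collections.abc import Sequence


def as_items(values: list[str]) -> Sequence[str]:
    # `Sequence` is allowed in this annotation because `from __future__ import
    # annotations` keeps annotations as strings instead of evaluating them.
    return tuple(values)


def head(values: object) -> str:
    # Quoted cast target, like cast("MTEBTasks", tasks) above: an unquoted
    # cast(Sequence[str], values) would raise NameError at runtime because
    # Sequence is only imported under TYPE_CHECKING.
    items = cast("Sequence[str]", values)
    return items[0]


print(head(as_items(["mteb", "benchmark"])))  # prints "mteb"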
7 changes: 6 additions & 1 deletion docs/overview/create_available_models.py
@@ -1,11 +1,16 @@
"""Updates the available models markdown files."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from prettify_list import pretty_long_list

import mteb
from mteb.models import ModelMeta

if TYPE_CHECKING:
    from mteb.models import ModelMeta

model_entry = """
#### {model_name_w_link}
2 changes: 1 addition & 1 deletion docs/overview/create_available_tasks.py
@@ -83,7 +83,7 @@ def task_category_to_string(category: str) -> str:


def create_aggregate_table(task: AbsTaskAggregate) -> str:
    tasks = cast(MTEBTasks, MTEBTasks(task.metadata.tasks))
    tasks = cast("MTEBTasks", MTEBTasks(task.metadata.tasks))
    df = tasks.to_dataframe(["name", "type", "modalities", "languages"])
    df["name"] = df.apply(
        lambda row: (
25 changes: 16 additions & 9 deletions mteb/_create_dataloaders.py
@@ -1,21 +1,28 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Callable
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

import torch
from datasets import Dataset, Image
from torch.utils.data import DataLoader, default_collate

from mteb.abstasks.task_metadata import TaskMetadata
from mteb.types import (
    BatchedInput,
    Conversation,
    ConversationTurn,
    PromptType,
    QueryDatasetType,
)
from mteb.types._encoder_io import CorpusInput, ImageInput, QueryInput, TextInput

if TYPE_CHECKING:
    from collections.abc import Callable

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.types import (
        BatchedInput,
        Conversation,
        QueryDatasetType,
    )
    from mteb.types._encoder_io import CorpusInput, ImageInput, QueryInput, TextInput

logger = logging.getLogger(__name__)

@@ -128,7 +135,7 @@ def _convert_conv_history_to_query(
    conversation = row["text"]
    # if it's a list of strings, just join them
    if isinstance(conversation, list) and isinstance(conversation[0], str):
        conversation_ = cast(list[str], conversation)
        conversation_ = cast("list[str]", conversation)
        conv_str = "; ".join(conversation_)
        current_conversation = [
            ConversationTurn(role="user", content=message) for message in conversation_
@@ -173,7 +180,7 @@

row["text"] = conv_str
row["conversation"] = current_conversation
return cast(dict[str, str | list[ConversationTurn]], row)
return cast("dict[str, str | list[ConversationTurn]]", row)


def _create_dataloader_for_queries_conversation(
15 changes: 10 additions & 5 deletions mteb/_evaluators/any_sts_evaluator.py
@@ -1,21 +1,26 @@
from __future__ import annotations

import logging
from typing import TypedDict
from typing import TYPE_CHECKING, TypedDict

from datasets import Dataset
from sklearn.metrics.pairwise import (
    paired_cosine_distances,
    paired_euclidean_distances,
    paired_manhattan_distances,
)

from mteb._create_dataloaders import create_dataloader
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import EncoderProtocol
from mteb.similarity_functions import compute_pairwise_similarity
from mteb.types import EncodeKwargs, PromptType

from .evaluator import Evaluator

if TYPE_CHECKING:
    from datasets import Dataset

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models import EncoderProtocol
    from mteb.types import EncodeKwargs, PromptType

logger = logging.getLogger(__name__)


14 changes: 10 additions & 4 deletions mteb/_evaluators/clustering_evaluator.py
@@ -1,15 +1,21 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from datasets import Dataset
from sklearn import cluster

from mteb._create_dataloaders import create_dataloader
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import EncoderProtocol
from mteb.types import EncodeKwargs

from .evaluator import Evaluator

if TYPE_CHECKING:
    from datasets import Dataset

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models import EncoderProtocol
    from mteb.types import EncodeKwargs

logger = logging.getLogger(__name__)


13 changes: 9 additions & 4 deletions mteb/_evaluators/evaluator.py
@@ -1,10 +1,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

from mteb.abstasks.abstask import _set_seed
from mteb.models import EncoderProtocol
from mteb.types import EncodeKwargs

if TYPE_CHECKING:
    from collections.abc import Iterable, Mapping

    from mteb.models import EncoderProtocol
    from mteb.types import EncodeKwargs


class Evaluator(ABC):
10 changes: 6 additions & 4 deletions mteb/_evaluators/image/imagetext_pairclassification_evaluator.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import torch
@@ -14,13 +13,16 @@
)
from mteb._evaluators.evaluator import Evaluator
from mteb._requires_package import requires_image_dependencies
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models.models_protocols import EncoderProtocol
from mteb.types import EncodeKwargs

if TYPE_CHECKING:
    from collections.abc import Sequence

    from PIL.Image import Image

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models.models_protocols import EncoderProtocol
    from mteb.types import EncodeKwargs


logger = logging.getLogger(__name__)

15 changes: 10 additions & 5 deletions mteb/_evaluators/pair_classification_evaluator.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import logging
from typing import Any, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict

import numpy as np
from datasets import Dataset
from sklearn.metrics.pairwise import (
    paired_cosine_distances,
    paired_euclidean_distances,
@@ -11,10 +12,14 @@

from mteb._create_dataloaders import _create_dataloader_from_texts, create_dataloader
from mteb._evaluators.evaluator import Evaluator
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import EncoderProtocol
from mteb.similarity_functions import compute_pairwise_similarity
from mteb.types import EncodeKwargs, PromptType

if TYPE_CHECKING:
    from datasets import Dataset

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models import EncoderProtocol
    from mteb.types import EncodeKwargs, PromptType

logger = logging.getLogger(__name__)

32 changes: 19 additions & 13 deletions mteb/_evaluators/retrieval_evaluator.py
@@ -1,23 +1,29 @@
import logging
from collections.abc import Sequence
from __future__ import annotations

from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import SearchProtocol
from mteb.types import (
    CorpusDatasetType,
    EncodeKwargs,
    QueryDatasetType,
    RelevantDocumentsType,
    RetrievalEvaluationResult,
    RetrievalOutputType,
    TopRankedDocumentsType,
)
import logging
from typing import TYPE_CHECKING

from .evaluator import Evaluator
from .retrieval_metrics import (
    calculate_retrieval_scores,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models import SearchProtocol
    from mteb.types import (
        CorpusDatasetType,
        EncodeKwargs,
        QueryDatasetType,
        RelevantDocumentsType,
        RetrievalEvaluationResult,
        RetrievalOutputType,
        TopRankedDocumentsType,
    )


logger = logging.getLogger(__name__)


12 changes: 9 additions & 3 deletions mteb/_evaluators/retrieval_metrics.py
@@ -1,15 +1,21 @@
from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
import pytrec_eval
from packaging.version import Version
from sklearn.metrics import auc

from mteb.types import RelevantDocumentsType, RetrievalEvaluationResult
from mteb.types import RetrievalEvaluationResult

if TYPE_CHECKING:
    from collections.abc import Mapping

    from mteb.types import RelevantDocumentsType

logger = logging.getLogger(__name__)

24 changes: 14 additions & 10 deletions mteb/_evaluators/sklearn_evaluator.py
@@ -1,18 +1,22 @@
import logging
from typing import Any, Protocol, cast
from __future__ import annotations

import numpy as np
from datasets import Dataset
from torch.utils.data import DataLoader
from typing_extensions import Self
import logging
from typing import TYPE_CHECKING, Any, Protocol, cast

from mteb._create_dataloaders import create_dataloader
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import EncoderProtocol
from mteb.types import Array, BatchedInput, EncodeKwargs

from .evaluator import Evaluator

if TYPE_CHECKING:
    import numpy as np
    from datasets import Dataset
    from torch.utils.data import DataLoader
    from typing_extensions import Self

    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models import EncoderProtocol
    from mteb.types import Array, BatchedInput, EncodeKwargs

logger = logging.getLogger(__name__)


@@ -104,7 +108,7 @@ def __call__( # type: ignore[override]
            hf_subset=self.hf_subset,
            **encode_kwargs,
        )
        test_cache = cast(Array, test_cache)
        test_cache = cast("Array", test_cache)

        logger.info("Running - Fitting classifier...")
        y_train = self.train_dataset[self.label_column_name]
11 changes: 8 additions & 3 deletions mteb/_evaluators/text/bitext_mining_evaluator.py
@@ -1,14 +1,19 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch
from datasets import Dataset
from tqdm.auto import tqdm

from mteb._create_dataloaders import _create_dataloader_from_texts
from mteb._evaluators.evaluator import Evaluator
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models import EncoderProtocol
from mteb.types import Array, EncodeKwargs

if TYPE_CHECKING:
    from mteb.abstasks.task_metadata import TaskMetadata
    from mteb.models import EncoderProtocol
    from mteb.types import Array, EncodeKwargs

logger = logging.getLogger(__name__)
