diff --git a/docs/overview/create_available_benchmarks.py b/docs/overview/create_available_benchmarks.py index 01bb873e6f..c6701e24c8 100644 --- a/docs/overview/create_available_benchmarks.py +++ b/docs/overview/create_available_benchmarks.py @@ -1,13 +1,15 @@ """Updates the available benchmarks markdown file.""" from pathlib import Path -from typing import cast +from typing import TYPE_CHECKING, cast from prettify_list import pretty_long_list from slugify import slugify_anchor import mteb -from mteb.get_tasks import MTEBTasks + +if TYPE_CHECKING: + from mteb.get_tasks import MTEBTasks benchmark_entry = """ #### {benchmark_name} @@ -38,7 +40,7 @@ def create_table(benchmark: mteb.Benchmark) -> str: """Create a markdown table of tasks in the benchmark.""" tasks = benchmark.tasks - tasks = cast(MTEBTasks, tasks) + tasks = cast("MTEBTasks", tasks) df = tasks.to_dataframe(["name", "type", "modalities", "languages"]) # add links to task names: diff --git a/docs/overview/create_available_models.py b/docs/overview/create_available_models.py index eb4b6a976f..4b64cbf981 100644 --- a/docs/overview/create_available_models.py +++ b/docs/overview/create_available_models.py @@ -1,11 +1,16 @@ """Updates the available models markdown files.""" +from __future__ import annotations + from pathlib import Path +from typing import TYPE_CHECKING from prettify_list import pretty_long_list import mteb -from mteb.models import ModelMeta + +if TYPE_CHECKING: + from mteb.models import ModelMeta model_entry = """ #### {model_name_w_link} diff --git a/docs/overview/create_available_tasks.py b/docs/overview/create_available_tasks.py index 444d2fac4f..5dde6450dd 100644 --- a/docs/overview/create_available_tasks.py +++ b/docs/overview/create_available_tasks.py @@ -83,7 +83,7 @@ def task_category_to_string(category: str) -> str: def create_aggregate_table(task: AbsTaskAggregate) -> str: - tasks = cast(MTEBTasks, MTEBTasks(task.metadata.tasks)) + tasks = cast("MTEBTasks", MTEBTasks(task.metadata.tasks)) df = tasks.to_dataframe(["name", "type", "modalities", "languages"]) df["name"] = df.apply( lambda row: ( diff --git a/mteb/_create_dataloaders.py b/mteb/_create_dataloaders.py index 6e3126ae76..fe2fdafd5e 100644 --- a/mteb/_create_dataloaders.py +++ b/mteb/_create_dataloaders.py @@ -1,21 +1,28 @@ +from __future__ import annotations + import logging import warnings -from collections.abc import Callable -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import torch from datasets import Dataset, Image from torch.utils.data import DataLoader, default_collate -from mteb.abstasks.task_metadata import TaskMetadata from mteb.types import ( - BatchedInput, - Conversation, ConversationTurn, PromptType, - QueryDatasetType, ) -from mteb.types._encoder_io import CorpusInput, ImageInput, QueryInput, TextInput + +if TYPE_CHECKING: + from collections.abc import Callable + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import ( + BatchedInput, + Conversation, + QueryDatasetType, + ) + from mteb.types._encoder_io import CorpusInput, ImageInput, QueryInput, TextInput logger = logging.getLogger(__name__) @@ -128,7 +135,7 @@ def _convert_conv_history_to_query( conversation = row["text"] # if it's a list of strings, just join them if isinstance(conversation, list) and isinstance(conversation[0], str): - conversation_ = cast(list[str], conversation) + conversation_ = cast("list[str]", conversation) conv_str = "; ".join(conversation_) current_conversation = [ ConversationTurn(role="user", content=message) for message in conversation_ @@ -173,7 +180,7 @@ def _convert_conv_history_to_query( row["text"] = conv_str row["conversation"] = current_conversation - return cast(dict[str, str | list[ConversationTurn]], row) + return cast("dict[str, str | list[ConversationTurn]]", row) def _create_dataloader_for_queries_conversation( diff --git a/mteb/_evaluators/any_sts_evaluator.py b/mteb/_evaluators/any_sts_evaluator.py index a626627010..baf21517cb 100644 --- a/mteb/_evaluators/any_sts_evaluator.py +++ b/mteb/_evaluators/any_sts_evaluator.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict -from datasets import Dataset from sklearn.metrics.pairwise import ( paired_cosine_distances, paired_euclidean_distances, @@ -9,13 +10,17 @@ ) from mteb._create_dataloaders import create_dataloader -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol from mteb.similarity_functions import compute_pairwise_similarity -from mteb.types import EncodeKwargs, PromptType from .evaluator import Evaluator +if TYPE_CHECKING: + from datasets import Dataset + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import EncodeKwargs, PromptType + logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/clustering_evaluator.py b/mteb/_evaluators/clustering_evaluator.py index 64b31d5fb8..5d1c2bbc6d 100644 --- a/mteb/_evaluators/clustering_evaluator.py +++ b/mteb/_evaluators/clustering_evaluator.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING -from datasets import Dataset from sklearn import cluster from mteb._create_dataloaders import create_dataloader -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol -from mteb.types import EncodeKwargs from .evaluator import Evaluator +if TYPE_CHECKING: + from datasets import Dataset + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import EncodeKwargs + logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/evaluator.py b/mteb/_evaluators/evaluator.py index b49557706b..74002ae0ad 100644 --- a/mteb/_evaluators/evaluator.py +++ b/mteb/_evaluators/evaluator.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping -from typing import Any +from typing import TYPE_CHECKING, Any from mteb.abstasks.abstask import _set_seed -from mteb.models import EncoderProtocol -from mteb.types import EncodeKwargs + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + from mteb.models import EncoderProtocol + from mteb.types import EncodeKwargs class Evaluator(ABC): diff --git a/mteb/_evaluators/image/imagetext_pairclassification_evaluator.py b/mteb/_evaluators/image/imagetext_pairclassification_evaluator.py index ace6e8f1cf..fc77eb824d 100644 --- a/mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +++ b/mteb/_evaluators/image/imagetext_pairclassification_evaluator.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -from collections.abc import Sequence from typing import TYPE_CHECKING, Any import torch @@ -14,13 +13,16 @@ ) from mteb._evaluators.evaluator import Evaluator from mteb._requires_package import requires_image_dependencies -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models.models_protocols import EncoderProtocol -from mteb.types import EncodeKwargs if TYPE_CHECKING: + from collections.abc import Sequence + from PIL.Image import Image + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models.models_protocols import EncoderProtocol + from mteb.types import EncodeKwargs + logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/pair_classification_evaluator.py b/mteb/_evaluators/pair_classification_evaluator.py index e0dbec00b1..48b2a72a1a 100644 --- a/mteb/_evaluators/pair_classification_evaluator.py +++ b/mteb/_evaluators/pair_classification_evaluator.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict import numpy as np -from datasets import Dataset from sklearn.metrics.pairwise import ( paired_cosine_distances, paired_euclidean_distances, @@ -11,10 +12,14 @@ from mteb._create_dataloaders import _create_dataloader_from_texts, create_dataloader from mteb._evaluators.evaluator import Evaluator -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol from mteb.similarity_functions import compute_pairwise_similarity -from mteb.types import EncodeKwargs, PromptType + +if TYPE_CHECKING: + from datasets import Dataset + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import EncodeKwargs, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/retrieval_evaluator.py b/mteb/_evaluators/retrieval_evaluator.py index 425ffaecd5..c2d269c81f 100644 --- a/mteb/_evaluators/retrieval_evaluator.py +++ b/mteb/_evaluators/retrieval_evaluator.py @@ -1,23 +1,29 @@ -import logging -from collections.abc import Sequence +from __future__ import annotations -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import SearchProtocol -from mteb.types import ( - CorpusDatasetType, - EncodeKwargs, - QueryDatasetType, - RelevantDocumentsType, - RetrievalEvaluationResult, - RetrievalOutputType, - TopRankedDocumentsType, -) +import logging +from typing import TYPE_CHECKING from .evaluator import Evaluator from .retrieval_metrics import ( calculate_retrieval_scores, ) +if TYPE_CHECKING: + from collections.abc import Sequence + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import SearchProtocol + from mteb.types import ( + CorpusDatasetType, + EncodeKwargs, + QueryDatasetType, + RelevantDocumentsType, + RetrievalEvaluationResult, + RetrievalOutputType, + TopRankedDocumentsType, + ) + + logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/retrieval_metrics.py b/mteb/_evaluators/retrieval_metrics.py index 1d1f2b51bb..895147ac03 100644 --- a/mteb/_evaluators/retrieval_metrics.py +++ b/mteb/_evaluators/retrieval_metrics.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging from collections import defaultdict -from collections.abc import Mapping -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -9,7 +10,12 @@ from packaging.version import Version from sklearn.metrics import auc -from mteb.types import RelevantDocumentsType, RetrievalEvaluationResult +from mteb.types import RetrievalEvaluationResult + +if TYPE_CHECKING: + from collections.abc import Mapping + + from mteb.types import RelevantDocumentsType logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/sklearn_evaluator.py b/mteb/_evaluators/sklearn_evaluator.py index aedb37df1f..8705dafde3 100644 --- a/mteb/_evaluators/sklearn_evaluator.py +++ b/mteb/_evaluators/sklearn_evaluator.py @@ -1,18 +1,22 @@ -import logging -from typing import Any, Protocol, cast +from __future__ import annotations -import numpy as np -from datasets import Dataset -from torch.utils.data import DataLoader -from typing_extensions import Self +import logging +from typing import TYPE_CHECKING, Any, Protocol, cast from mteb._create_dataloaders import create_dataloader -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol -from mteb.types import Array, BatchedInput, EncodeKwargs from .evaluator import Evaluator +if TYPE_CHECKING: + import numpy as np + from datasets import Dataset + from torch.utils.data import DataLoader + from typing_extensions import Self + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import Array, BatchedInput, EncodeKwargs + logger = logging.getLogger(__name__) @@ -104,7 +108,7 @@ def __call__( # type: ignore[override] hf_subset=self.hf_subset, **encode_kwargs, ) - test_cache = cast(Array, test_cache) + test_cache = cast("Array", test_cache) logger.info("Running - Fitting classifier...") y_train = self.train_dataset[self.label_column_name] diff --git a/mteb/_evaluators/text/bitext_mining_evaluator.py b/mteb/_evaluators/text/bitext_mining_evaluator.py index a2dc99e5c1..687326a743 100644 --- a/mteb/_evaluators/text/bitext_mining_evaluator.py +++ b/mteb/_evaluators/text/bitext_mining_evaluator.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import torch from datasets import Dataset @@ -6,9 +9,11 @@ from mteb._create_dataloaders import _create_dataloader_from_texts from mteb._evaluators.evaluator import Evaluator -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol -from mteb.types import Array, EncodeKwargs + +if TYPE_CHECKING: + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import Array, EncodeKwargs logger = logging.getLogger(__name__) diff --git a/mteb/_evaluators/text/summarization_evaluator.py b/mteb/_evaluators/text/summarization_evaluator.py index a9b16340a4..0ccb3ebc65 100644 --- a/mteb/_evaluators/text/summarization_evaluator.py +++ b/mteb/_evaluators/text/summarization_evaluator.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import logging import sys -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict import numpy as np import torch @@ -9,10 +11,12 @@ from mteb._create_dataloaders import _create_dataloader_from_texts from mteb._evaluators.evaluator import Evaluator -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol from mteb.similarity_functions import cos_sim, dot_score -from mteb.types import EncodeKwargs + +if TYPE_CHECKING: + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import EncodeKwargs # if later than python 3.13 use typing module if sys.version_info >= (3, 13): diff --git a/mteb/_evaluators/zeroshot_classification_evaluator.py b/mteb/_evaluators/zeroshot_classification_evaluator.py index 502549cefd..1baf533779 100644 --- a/mteb/_evaluators/zeroshot_classification_evaluator.py +++ b/mteb/_evaluators/zeroshot_classification_evaluator.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING from datasets import Dataset @@ -6,13 +9,17 @@ _create_dataloader_from_texts, create_dataloader, ) -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models import EncoderProtocol from mteb.similarity_functions import similarity -from mteb.types import Array, EncodeKwargs from .evaluator import Evaluator +if TYPE_CHECKING: + from datasets import Dataset + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import EncoderProtocol + from mteb.types import Array, EncodeKwargs + logger = logging.getLogger(__name__) diff --git a/mteb/_helpful_enum.py b/mteb/_helpful_enum.py index d80073b90b..d9ab668959 100644 --- a/mteb/_helpful_enum.py +++ b/mteb/_helpful_enum.py @@ -1,6 +1,10 @@ +from __future__ import annotations + from enum import Enum +from typing import TYPE_CHECKING -from typing_extensions import Self +if TYPE_CHECKING: + from typing_extensions import Self class HelpfulStrEnum(str, Enum): diff --git a/mteb/abstasks/_data_filter/filters.py b/mteb/abstasks/_data_filter/filters.py index 16ed5e8d97..96f9d270f5 100644 --- a/mteb/abstasks/_data_filter/filters.py +++ b/mteb/abstasks/_data_filter/filters.py @@ -1,12 +1,18 @@ """Simplified version of https://gist.github.com/AlexeyVatolin/ea3adc21aa7a767603ff393b22085adc from https://github.com/embeddings-benchmark/mteb/pull/2900""" +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import datasets import pandas as pd -from datasets import Dataset, DatasetDict +from datasets import DatasetDict + +if TYPE_CHECKING: + from datasets import Dataset -from mteb import TaskMetadata + from mteb import TaskMetadata logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/_data_filter/task_pipelines.py b/mteb/abstasks/_data_filter/task_pipelines.py index c376edc546..563030d90f 100644 --- a/mteb/abstasks/_data_filter/task_pipelines.py +++ b/mteb/abstasks/_data_filter/task_pipelines.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING from datasets import DatasetDict -from mteb import TaskMetadata -from mteb.abstasks import AbsTaskClassification from mteb.abstasks._data_filter.filters import ( deduplicate, filter_empty, @@ -13,6 +14,10 @@ split_train_test, ) +if TYPE_CHECKING: + from mteb import TaskMetadata + from mteb.abstasks import AbsTaskClassification + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/_statistics_calculation.py b/mteb/abstasks/_statistics_calculation.py index 598d50af71..f5d0029de0 100644 --- a/mteb/abstasks/_statistics_calculation.py +++ b/mteb/abstasks/_statistics_calculation.py @@ -2,10 +2,8 @@ import hashlib from collections import Counter -from collections.abc import Mapping from typing import TYPE_CHECKING, cast -from mteb.types import TopRankedDocumentsType from mteb.types.statistics import ( ImageStatistics, LabelStatistics, @@ -16,8 +14,12 @@ ) if TYPE_CHECKING: + from collections.abc import Mapping + from PIL import Image + from mteb.types import TopRankedDocumentsType + def calculate_text_statistics(texts: list[str]) -> TextStatistics: """Calculate descriptive statistics for a list of texts. @@ -87,13 +89,13 @@ def calculate_label_statistics(labels: list[int | list[int]]) -> LabelStatistics if not isinstance(labels[0], list): # single label classification - single_label = cast(list[int], labels) + single_label = cast("list[int]", labels) label_len = [1] * len(single_label) total_label_len = len(single_label) total_labels.extend(single_label) elif isinstance(labels[0], list): # multilabel classification - multilabel_labels = cast(list[list[int]], labels) + multilabel_labels = cast("list[list[int]]", labels) label_len = [len(l) for l in multilabel_labels] total_label_len = sum(label_len) for l in multilabel_labels: diff --git a/mteb/abstasks/abstask.py b/mteb/abstasks/abstask.py index 84112d35ea..6b21c77b68 100644 --- a/mteb/abstasks/abstask.py +++ b/mteb/abstasks/abstask.py @@ -1,30 +1,38 @@ +from __future__ import annotations + import json import logging import warnings from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from copy import copy from pathlib import Path -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np from datasets import ClassLabel, Dataset, DatasetDict, load_dataset from sklearn.preprocessing import MultiLabelBinarizer from tqdm.auto import tqdm -from typing_extensions import Self from mteb._set_seed import _set_seed -from mteb.abstasks.task_metadata import TaskMetadata from mteb.languages import LanguageScripts from mteb.models import ( CrossEncoderProtocol, EncoderProtocol, - MTEBModels, SearchProtocol, ) -from mteb.types import HFSubset, Modalities, ScoresDict -from mteb.types._encoder_io import EncodeKwargs -from mteb.types.statistics import DescriptiveStatistics, SplitDescriptiveStatistics + +if TYPE_CHECKING: + from collections.abc import Mapping + + from typing_extensions import Self + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models import ( + MTEBModels, + ) + from mteb.types import EncodeKwargs, HFSubset, Modalities, ScoresDict + from mteb.types.statistics import DescriptiveStatistics, SplitDescriptiveStatistics logger = logging.getLogger(__name__) @@ -163,7 +171,7 @@ def evaluate( if not self.data_loaded: self.load_data() - self.dataset = cast(dict[HFSubset, DatasetDict], self.dataset) + self.dataset = cast("dict[HFSubset, DatasetDict]", self.dataset) scores = {} if self.hf_subsets is None: diff --git a/mteb/abstasks/aggregate_task_metadata.py b/mteb/abstasks/aggregate_task_metadata.py index 560fb7c60f..7e26d39a0b 100644 --- a/mteb/abstasks/aggregate_task_metadata.py +++ b/mteb/abstasks/aggregate_task_metadata.py @@ -1,28 +1,39 @@ +from __future__ import annotations + import logging from datetime import datetime +from typing import TYPE_CHECKING from pydantic import ConfigDict, Field, model_validator -from typing_extensions import Self from mteb.types import ( - ISOLanguageScript, Languages, - Licenses, - Modalities, - StrDate, ) from .abstask import AbsTask from .task_metadata import ( - AnnotatorType, MetadataDatasetDict, - SampleCreationMethod, - TaskDomain, TaskMetadata, - TaskSubtype, TaskType, ) +if TYPE_CHECKING: + from typing_extensions import Self + + from mteb.types import ( + ISOLanguageScript, + Licenses, + Modalities, + StrDate, + ) + + from .task_metadata import ( + AnnotatorType, + SampleCreationMethod, + TaskDomain, + TaskSubtype, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/aggregated_task.py b/mteb/abstasks/aggregated_task.py index 3fedeb6642..b55ad77a8d 100644 --- a/mteb/abstasks/aggregated_task.py +++ b/mteb/abstasks/aggregated_task.py @@ -1,19 +1,26 @@ +from __future__ import annotations + import logging import warnings -from collections.abc import Mapping -from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -from datasets import Dataset, DatasetDict -from mteb.models.models_protocols import MTEBModels from mteb.results.task_result import TaskResult -from mteb.types import EncodeKwargs, HFSubset, ScoresDict -from mteb.types.statistics import DescriptiveStatistics from .abstask import AbsTask -from .aggregate_task_metadata import AggregateTaskMetadata + +if TYPE_CHECKING: + from collections.abc import Mapping + from pathlib import Path + + from datasets import Dataset, DatasetDict + + from mteb.models.models_protocols import MTEBModels + from mteb.types import EncodeKwargs, HFSubset, ScoresDict + from mteb.types.statistics import DescriptiveStatistics + + from .aggregate_task_metadata import AggregateTaskMetadata logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/classification.py b/mteb/abstasks/classification.py index 6538123981..20b729b251 100644 --- a/mteb/abstasks/classification.py +++ b/mteb/abstasks/classification.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging from collections import defaultdict -from pathlib import Path -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict import numpy as np from datasets import Dataset, DatasetDict @@ -16,12 +17,8 @@ from mteb._evaluators.sklearn_evaluator import SklearnEvaluator, SklearnModelProtocol from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs, HFSubset, ScoresDict from mteb.types.statistics import ( - ImageStatistics, - LabelStatistics, SplitDescriptiveStatistics, - TextStatistics, ) from ._statistics_calculation import ( @@ -31,6 +28,18 @@ ) from .abstask import AbsTask +if TYPE_CHECKING: + from pathlib import Path + + from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol + from mteb.models import MTEBModels + from mteb.types import EncodeKwargs, HFSubset, ScoresDict + from mteb.types.statistics import ( + ImageStatistics, + LabelStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/clustering.py b/mteb/abstasks/clustering.py index 3346876ed4..87d2161ac9 100644 --- a/mteb/abstasks/clustering.py +++ b/mteb/abstasks/clustering.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import itertools import logging import random from collections import defaultdict -from pathlib import Path -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np from datasets import Dataset, DatasetDict @@ -11,13 +12,10 @@ from sklearn.metrics.cluster import v_measure_score from mteb._create_dataloaders import create_dataloader -from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import Array, EncodeKwargs, HFSubset, ScoresDict +from mteb.models import EncoderProtocol +from mteb.types import Array, HFSubset from mteb.types.statistics import ( - ImageStatistics, - LabelStatistics, SplitDescriptiveStatistics, - TextStatistics, ) from ._statistics_calculation import ( @@ -27,6 +25,17 @@ ) from .abstask import AbsTask +if TYPE_CHECKING: + from pathlib import Path + + from mteb.models import MTEBModels + from mteb.types import Array, EncodeKwargs, ScoresDict + from mteb.types.statistics import ( + ImageStatistics, + LabelStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) @@ -186,7 +195,7 @@ def _evaluate_subset( self.max_fraction_of_documents_to_embed * len(data_split) ) else: - max_documents_to_embed = cast(int, self.max_document_to_embed) + max_documents_to_embed = cast("int", self.max_document_to_embed) max_documents_to_embed = min(len(data_split), max_documents_to_embed) example_indices = self.rng_state.sample( diff --git a/mteb/abstasks/clustering_legacy.py b/mteb/abstasks/clustering_legacy.py index e7928cf1a3..8a7ee7dfc1 100644 --- a/mteb/abstasks/clustering_legacy.py +++ b/mteb/abstasks/clustering_legacy.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import logging -from pathlib import Path -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict import numpy as np from datasets import Dataset @@ -9,12 +10,8 @@ from mteb._evaluators import ClusteringEvaluator from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs, ScoresDict from mteb.types.statistics import ( - ImageStatistics, - LabelStatistics, SplitDescriptiveStatistics, - TextStatistics, ) from ._statistics_calculation import ( @@ -24,6 +21,17 @@ ) from .abstask import AbsTask +if TYPE_CHECKING: + from pathlib import Path + + from mteb.models import MTEBModels + from mteb.types import EncodeKwargs, ScoresDict + from mteb.types.statistics import ( + ImageStatistics, + LabelStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/image/image_text_pair_classification.py b/mteb/abstasks/image/image_text_pair_classification.py index c586b846ac..f469f5e13e 100644 --- a/mteb/abstasks/image/image_text_pair_classification.py +++ b/mteb/abstasks/image/image_text_pair_classification.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import logging from collections.abc import Sequence -from pathlib import Path -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict import torch -from datasets import Dataset, concatenate_datasets +from datasets import concatenate_datasets from mteb._evaluators import ImageTextPairClassificationEvaluator from mteb.abstasks._statistics_calculation import ( @@ -12,14 +13,23 @@ calculate_text_statistics, ) from mteb.abstasks.abstask import AbsTask -from mteb.models.models_protocols import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs +from mteb.models.models_protocols import EncoderProtocol from mteb.types.statistics import ( - ImageStatistics, SplitDescriptiveStatistics, - TextStatistics, ) +if TYPE_CHECKING: + from pathlib import Path + + from datasets import Dataset + + from mteb.models.models_protocols import MTEBModels + from mteb.types import EncodeKwargs + from mteb.types.statistics import ( + ImageStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/multilabel_classification.py b/mteb/abstasks/multilabel_classification.py index 9347763582..3ecfd25452 100644 --- a/mteb/abstasks/multilabel_classification.py +++ b/mteb/abstasks/multilabel_classification.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import itertools import logging from collections import defaultdict -from pathlib import Path -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict import numpy as np from datasets import DatasetDict @@ -15,12 +16,17 @@ from mteb._create_dataloaders import create_dataloader from mteb._evaluators.classification_metrics import hamming_score -from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol -from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import Array, EncodeKwargs +from mteb.models import EncoderProtocol from .classification import AbsTaskClassification +if TYPE_CHECKING: + from pathlib import Path + + from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol + from mteb.models import MTEBModels + from mteb.types import Array, EncodeKwargs + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/pair_classification.py b/mteb/abstasks/pair_classification.py index 85ef7fce45..21ecb9564c 100644 --- a/mteb/abstasks/pair_classification.py +++ b/mteb/abstasks/pair_classification.py @@ -1,16 +1,15 @@ +from __future__ import annotations + import hashlib import logging from collections import defaultdict -from pathlib import Path +from typing import TYPE_CHECKING import numpy as np from datasets import Dataset from sklearn.metrics import average_precision_score from mteb._evaluators import PairClassificationEvaluator -from mteb._evaluators.pair_classification_evaluator import ( - PairClassificationDistances, -) from mteb.abstasks._statistics_calculation import ( calculate_image_statistics, calculate_label_statistics, @@ -18,15 +17,26 @@ ) from mteb.abstasks.abstask import AbsTask from mteb.models.model_meta import ScoringFunction -from mteb.models.models_protocols import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs, PromptType +from mteb.models.models_protocols import EncoderProtocol from mteb.types.statistics import ( - ImageStatistics, - LabelStatistics, SplitDescriptiveStatistics, - TextStatistics, ) +if TYPE_CHECKING: + from pathlib import Path + + from mteb._evaluators.pair_classification_evaluator import ( + PairClassificationDistances, + ) + from mteb.models.models_protocols import MTEBModels + from mteb.types import EncodeKwargs, PromptType + from mteb.types.statistics import ( + ImageStatistics, + LabelStatistics, + TextStatistics, + ) + + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/regression.py b/mteb/abstasks/regression.py index bebdce8d86..ae697133bc 100644 --- a/mteb/abstasks/regression.py +++ b/mteb/abstasks/regression.py @@ -1,29 +1,37 @@ +from __future__ import annotations + import logging -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict import datasets import numpy as np import pandas as pd -from datasets import Dataset from scipy.stats import kendalltau from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score -from mteb._evaluators.sklearn_evaluator import SklearnEvaluator, SklearnModelProtocol +from mteb._evaluators.sklearn_evaluator import SklearnEvaluator from mteb.abstasks._statistics_calculation import ( calculate_image_statistics, calculate_score_statistics, calculate_text_statistics, ) from mteb.types.statistics import ( - ImageStatistics, - ScoreStatistics, SplitDescriptiveStatistics, - TextStatistics, ) from .classification import AbsTaskClassification +if TYPE_CHECKING: + from datasets import Dataset + + from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol + from mteb.types.statistics import ( + ImageStatistics, + ScoreStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/retrieval.py b/mteb/abstasks/retrieval.py index 7492db12c5..79ac12f0be 100644 --- a/mteb/abstasks/retrieval.py +++ b/mteb/abstasks/retrieval.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import json import logging from collections import defaultdict -from collections.abc import Callable, Mapping, Sequence from pathlib import Path from time import time -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal from datasets import Dataset, DatasetDict, concatenate_datasets -from typing_extensions import Self from mteb._create_dataloaders import ( _combine_queries_with_instruction_text, @@ -19,25 +19,12 @@ from mteb.models import ( CrossEncoderProtocol, EncoderProtocol, - MTEBModels, SearchCrossEncoderWrapper, SearchEncoderWrapper, SearchProtocol, ) -from mteb.types import ( - EncodeKwargs, - HFSubset, - QueryDatasetType, - RelevantDocumentsType, - RetrievalOutputType, - ScoresDict, -) from mteb.types.statistics import ( - ImageStatistics, - RelevantDocsStatistics, SplitDescriptiveStatistics, - TextStatistics, - TopRankedStatistics, ) from ._statistics_calculation import ( @@ -53,6 +40,30 @@ _combine_queries_with_instructions_datasets, ) +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + from typing_extensions import Self + + from mteb.models import ( + MTEBModels, + ) + from mteb.types import ( + EncodeKwargs, + HFSubset, + QueryDatasetType, + RelevantDocumentsType, + RetrievalOutputType, + ScoresDict, + ) + from mteb.types.statistics import ( + ImageStatistics, + RelevantDocsStatistics, + TextStatistics, + TopRankedStatistics, + ) + + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/retrieval_dataset_loaders.py b/mteb/abstasks/retrieval_dataset_loaders.py index 6c25dae129..57842a64e4 100644 --- a/mteb/abstasks/retrieval_dataset_loaders.py +++ b/mteb/abstasks/retrieval_dataset_loaders.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import logging -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict from datasets import ( Dataset, @@ -11,13 +13,14 @@ load_dataset, ) -from mteb.types import ( - CorpusDatasetType, - InstructionDatasetType, - QueryDatasetType, - RelevantDocumentsType, - TopRankedDocumentsType, -) +if TYPE_CHECKING: + from mteb.types import ( + CorpusDatasetType, + InstructionDatasetType, + QueryDatasetType, + RelevantDocumentsType, + TopRankedDocumentsType, + ) logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/sts.py b/mteb/abstasks/sts.py index a8b84cf9d5..5af55209e3 100644 --- a/mteb/abstasks/sts.py +++ b/mteb/abstasks/sts.py @@ -1,19 +1,14 @@ +from __future__ import annotations + import logging -from pathlib import Path -from typing import Any, TypedDict, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast -from datasets import Dataset from scipy.stats import pearsonr, spearmanr from mteb._evaluators import AnySTSEvaluator -from mteb._evaluators.any_sts_evaluator import STSEvaluatorScores -from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs, PromptType +from mteb.models import EncoderProtocol from mteb.types.statistics import ( - ImageStatistics, - ScoreStatistics, SplitDescriptiveStatistics, - TextStatistics, ) from ._statistics_calculation import ( @@ -23,6 +18,20 @@ ) from .abstask import AbsTask +if TYPE_CHECKING: + from pathlib import Path + + from datasets import Dataset + + from mteb._evaluators.any_sts_evaluator import STSEvaluatorScores + from mteb.models import MTEBModels + from mteb.types import EncodeKwargs, PromptType + from mteb.types.statistics import ( + ImageStatistics, + ScoreStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) @@ -182,7 +191,7 @@ def _calculate_descriptive_statistics_from_split( self, split: str, hf_subset: str | None = None, compute_overall: bool = False ) -> AnySTSDescriptiveStatistics: first_column, second_column = self.column_names - self.dataset = cast(dict[str, dict[str, Dataset]], self.dataset) + self.dataset = cast("dict[str, dict[str, Dataset]]", self.dataset) if hf_subset: sentence1 = self.dataset[hf_subset][split][first_column] diff --git a/mteb/abstasks/task_metadata.py b/mteb/abstasks/task_metadata.py index e1d0c91b65..d5a7dd8620 100644 --- a/mteb/abstasks/task_metadata.py +++ b/mteb/abstasks/task_metadata.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import json import logging from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast from huggingface_hub import ( - CardData, DatasetCard, DatasetCardData, constants, @@ -17,13 +18,11 @@ ConfigDict, field_validator, ) -from typing_extensions import Required, TypedDict +from typing_extensions import Required, TypedDict # noqa: TC002 import mteb from mteb.languages import check_language_code from mteb.types import ( - HFSubset, - ISOLanguageScript, Languages, Licenses, Modalities, @@ -31,7 +30,17 @@ StrDate, StrURL, ) -from mteb.types.statistics import DescriptiveStatistics + +if TYPE_CHECKING: + from huggingface_hub import ( + CardData, + ) + + from mteb.types import ( + HFSubset, + ISOLanguageScript, + ) + from mteb.types.statistics import DescriptiveStatistics logger = logging.getLogger(__name__) @@ -368,7 +377,7 @@ def hf_subsets_to_langscripts(self) -> dict[HFSubset, list[ISOLanguageScript]]: """Return a dictionary mapping huggingface subsets to languages.""" if isinstance(self.eval_langs, dict): return self.eval_langs - return {"default": cast(list[str], self.eval_langs)} + return {"default": cast("list[str]", self.eval_langs)} @property def intext_citation(self, include_cite: bool = True) -> str: @@ -697,7 +706,7 @@ def _hf_languages(self) -> list[str]: for val in self.eval_langs.values(): languages.extend(val) else: - languages = cast(list[str], self.eval_langs) + languages = cast("list[str]", self.eval_langs) # value "python" is not valid. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters), # or a special value like "code", "multilingual". readme_langs = [] diff --git a/mteb/abstasks/text/bitext_mining.py b/mteb/abstasks/text/bitext_mining.py index 8950efce80..3b37094585 100644 --- a/mteb/abstasks/text/bitext_mining.py +++ b/mteb/abstasks/text/bitext_mining.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging from collections import defaultdict -from pathlib import Path -from typing import Any, ClassVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from datasets import Dataset, DatasetDict from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score @@ -9,9 +10,15 @@ from mteb._evaluators import BitextMiningEvaluator from mteb.abstasks._statistics_calculation import calculate_text_statistics from mteb.abstasks.abstask import AbsTask -from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs, HFSubset, ScoresDict -from mteb.types.statistics import SplitDescriptiveStatistics, TextStatistics +from mteb.models import EncoderProtocol +from mteb.types.statistics import SplitDescriptiveStatistics + +if TYPE_CHECKING: + from pathlib import Path + + from mteb.models import MTEBModels + from mteb.types import EncodeKwargs, HFSubset, ScoresDict + from mteb.types.statistics import TextStatistics logger = logging.getLogger(__name__) @@ -90,7 +97,7 @@ def evaluate( if subsets_to_run is not None: hf_subsets = [s for s in hf_subsets if s in subsets_to_run] - encoder_model = cast(EncoderProtocol, model) + encoder_model = cast("EncoderProtocol", model) if self.dataset is None: raise ValueError("Dataset is not loaded.") @@ -127,7 +134,7 @@ def evaluate( **kwargs, ) - return cast(dict[HFSubset, ScoresDict], scores) + return cast("dict[HFSubset, ScoresDict]", scores) def _get_pairs(self, parallel: bool) -> list[tuple[str, str]]: pairs = self._DEFAULT_PAIR diff --git a/mteb/abstasks/text/summarization.py b/mteb/abstasks/text/summarization.py index 9d2fddae72..0d98caf4e4 100644 --- a/mteb/abstasks/text/summarization.py +++ b/mteb/abstasks/text/summarization.py @@ -1,24 +1,34 @@ +from __future__ import annotations + import logging -from pathlib import Path +from typing import TYPE_CHECKING import numpy as np -from datasets import Dataset from mteb._evaluators import SummarizationEvaluator -from mteb._evaluators.text.summarization_evaluator import SummarizationMetrics from mteb.abstasks._statistics_calculation import ( calculate_score_statistics, calculate_text_statistics, ) from mteb.abstasks.abstask import AbsTask -from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs +from mteb.models import EncoderProtocol from mteb.types.statistics import ( - ScoreStatistics, SplitDescriptiveStatistics, - TextStatistics, ) +if TYPE_CHECKING: + from pathlib import Path + + from datasets import Dataset + + from mteb._evaluators.text.summarization_evaluator import SummarizationMetrics + from mteb.models import MTEBModels + from mteb.types import EncodeKwargs + from mteb.types.statistics import ( + ScoreStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/abstasks/zeroshot_classification.py b/mteb/abstasks/zeroshot_classification.py index c48f1c3e3e..3ccccac75c 100644 --- a/mteb/abstasks/zeroshot_classification.py +++ b/mteb/abstasks/zeroshot_classification.py @@ -1,19 +1,16 @@ +from __future__ import annotations + import logging -from pathlib import Path -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict import torch from datasets import Dataset from sklearn import metrics from mteb._evaluators import ZeroShotClassificationEvaluator -from mteb.models import EncoderProtocol, MTEBModels -from mteb.types import EncodeKwargs +from mteb.models import EncoderProtocol from mteb.types.statistics import ( - ImageStatistics, - LabelStatistics, SplitDescriptiveStatistics, - TextStatistics, ) from ._statistics_calculation import ( @@ -23,6 +20,17 @@ ) from .abstask import AbsTask +if TYPE_CHECKING: + from pathlib import Path + + from mteb.models import MTEBModels + from mteb.types import EncodeKwargs + from mteb.types.statistics import ( + ImageStatistics, + LabelStatistics, + TextStatistics, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/benchmarks/_create_table.py b/mteb/benchmarks/_create_table.py index 23e4296339..a98b21b93e 100644 --- a/mteb/benchmarks/_create_table.py +++ b/mteb/benchmarks/_create_table.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import re from collections import defaultdict -from typing import Literal +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd import mteb from mteb.get_tasks import get_task, get_tasks -from mteb.results.benchmark_results import BenchmarkResults + +if TYPE_CHECKING: + from mteb.results.benchmark_results import BenchmarkResults def _borda_count(scores: pd.Series) -> pd.Series: diff --git a/mteb/cache.py b/mteb/cache.py index 2cfb327741..7c94ad7c7b 100644 --- a/mteb/cache.py +++ b/mteb/cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import gzip import io import json @@ -7,9 +9,8 @@ import subprocess import warnings from collections import defaultdict -from collections.abc import Iterable, Sequence from pathlib import Path -from typing import cast +from typing import TYPE_CHECKING, cast import requests @@ -18,7 +19,11 @@ from mteb.benchmarks.benchmark import Benchmark from mteb.models import ModelMeta from mteb.results import BenchmarkResults, ModelResult, TaskResult -from mteb.types import ModelName, Revision + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from mteb.types import ModelName, Revision logger = logging.getLogger(__name__) @@ -583,7 +588,7 @@ def _filter_paths_by_model_and_revision( first_model = next(iter(models)) if isinstance(first_model, ModelMeta): - models = cast(Iterable[ModelMeta], models) + models = cast("Iterable[ModelMeta]", models) name_and_revision = { (m.model_name_as_path(), m.revision or "no_revision_available") for m in models @@ -594,7 +599,7 @@ def _filter_paths_by_model_and_revision( if (p.parent.parent.name, p.parent.name) in name_and_revision ] - str_models = cast(Sequence[str], models) + str_models = cast("Sequence[str]", models) model_names = {m.replace("/", "__").replace(" ", "_") for m in str_models} return [p for p in paths if p.parent.parent.name in model_names] diff --git a/mteb/cli/_display_tasks.py b/mteb/cli/_display_tasks.py index cda4f36a00..a51fef207f 100644 --- a/mteb/cli/_display_tasks.py +++ b/mteb/cli/_display_tasks.py @@ -1,9 +1,15 @@ -from collections.abc import Iterable, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING -from mteb.abstasks import AbsTask -from mteb.benchmarks import Benchmark from mteb.get_tasks import MTEBTasks +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from mteb.abstasks import AbsTask + from mteb.benchmarks import Benchmark + def _display_benchmarks(benchmarks: Sequence[Benchmark]) -> None: """Get all benchmarks available in the MTEB.""" diff --git a/mteb/cli/build_cli.py b/mteb/cli/build_cli.py index d87614d049..db0befc7bf 100644 --- a/mteb/cli/build_cli.py +++ b/mteb/cli/build_cli.py @@ -3,17 +3,20 @@ import os import warnings from pathlib import Path +from typing import TYPE_CHECKING import torch from rich.logging import RichHandler import mteb -from mteb.abstasks.abstask import AbsTask from mteb.cache import ResultCache from mteb.cli._display_tasks import _display_benchmarks, _display_tasks from mteb.cli.generate_model_card import generate_model_card from mteb.evaluate import OverwriteStrategy -from mteb.types._encoder_io import EncodeKwargs + +if TYPE_CHECKING: + from mteb.abstasks.abstask import AbsTask + from mteb.types import EncodeKwargs logger = logging.getLogger(__name__) diff --git a/mteb/cli/generate_model_card.py b/mteb/cli/generate_model_card.py index eba490566d..a435a1253b 100644 --- a/mteb/cli/generate_model_card.py +++ b/mteb/cli/generate_model_card.py @@ -1,14 +1,21 @@ +from __future__ import annotations + import logging import warnings -from collections.abc import Sequence from pathlib import Path +from typing import TYPE_CHECKING from huggingface_hub import ModelCard, ModelCardData, repo_exists from mteb.abstasks.abstask import AbsTask -from mteb.benchmarks.benchmark import Benchmark from mteb.cache import ResultCache +if TYPE_CHECKING: + from collections.abc import Sequence + + from mteb.abstasks.abstask import AbsTask + from mteb.benchmarks.benchmark import Benchmark + logger = logging.getLogger(__name__) diff --git a/mteb/deprecated_evaluator.py b/mteb/deprecated_evaluator.py index 2ed2291c9e..14f0db9fda 100644 --- a/mteb/deprecated_evaluator.py +++ b/mteb/deprecated_evaluator.py @@ -6,7 +6,6 @@ import sys import traceback import warnings -from collections.abc import Iterable, Sequence from copy import deepcopy from datetime import datetime from itertools import chain @@ -18,26 +17,31 @@ import mteb from mteb.abstasks import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate -from mteb.abstasks.task_metadata import TaskCategory, TaskType from mteb.benchmarks import Benchmark from mteb.models import ( CrossEncoderWrapper, ModelMeta, - MTEBModels, SentenceTransformerEncoderWrapper, ) from mteb.results import TaskResult -from mteb.types import EncodeKwargs, ScoresDict + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from sentence_transformers import CrossEncoder, SentenceTransformer + + from mteb.abstasks.aggregated_task import AbsTaskAggregate + from mteb.abstasks.task_metadata import TaskCategory, TaskType + from mteb.models import ( + MTEBModels, + ) + from mteb.types import EncodeKwargs, ScoresDict if sys.version_info >= (3, 13): from warnings import deprecated else: from typing_extensions import deprecated -if TYPE_CHECKING: - from sentence_transformers import CrossEncoder, SentenceTransformer - logger = logging.getLogger(__name__) @@ -66,9 +70,9 @@ def __init__( """ if isinstance(next(iter(tasks)), Benchmark): self.benchmarks = tasks - self.tasks = list(chain.from_iterable(cast(Iterable[Benchmark], tasks))) + self.tasks = list(chain.from_iterable(cast("Iterable[Benchmark]", tasks))) elif isinstance(next(iter(tasks)), AbsTask): - self.tasks = list(cast(Iterable[AbsTask], tasks)) + self.tasks = list(cast("Iterable[AbsTask]", tasks)) self.err_logs_path = Path(err_logs_path) self._last_evaluated_splits: dict[str, list[str]] = {} @@ -313,7 +317,7 @@ def run( elif isinstance(model, CrossEncoder): mteb_model = CrossEncoderWrapper(model) else: - mteb_model = cast(MTEBModels, model) + mteb_model = cast("MTEBModels", model) meta = self.create_model_meta(mteb_model) output_path = self._create_output_folder(meta, output_folder) @@ -346,7 +350,7 @@ def run( ) if task.is_aggregate: - aggregated_task = cast(AbsTaskAggregate, task) + aggregated_task = cast("AbsTaskAggregate", task) self_ = MTEB(tasks=aggregated_task.metadata.tasks) aggregated_task_results = self_.run( mteb_model, diff --git a/mteb/evaluate.py b/mteb/evaluate.py index 0b24b37074..2fee6a9434 100644 --- a/mteb/evaluate.py +++ b/mteb/evaluate.py @@ -2,7 +2,6 @@ import logging import warnings -from collections.abc import Iterable from pathlib import Path from time import time from typing import TYPE_CHECKING, cast @@ -17,22 +16,25 @@ from mteb.benchmarks.benchmark import Benchmark from mteb.cache import ResultCache from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import ( - MTEBModels, -) from mteb.models.sentence_transformer_wrapper import ( CrossEncoderWrapper, SentenceTransformerEncoderWrapper, ) from mteb.results import ModelResult, TaskResult from mteb.results.task_result import TaskError -from mteb.types import HFSubset, PromptType, SplitName -from mteb.types._encoder_io import EncodeKwargs -from mteb.types._metadata import ModelName, Revision +from mteb.types import PromptType if TYPE_CHECKING: + from collections.abc import Iterable + from sentence_transformers import CrossEncoder, SentenceTransformer + from mteb.models.models_protocols import ( + MTEBModels, + ) + from mteb.types import EncodeKwargs, HFSubset, SplitName + from mteb.types._metadata import ModelName, Revision + logger = logging.getLogger(__name__) @@ -69,13 +71,13 @@ def _sanitize_model( meta = getattr(model, "mteb_model_meta") if not isinstance(meta, ModelMeta): meta = ModelMeta._from_hub(None) - wrapped_model = cast(MTEBModels | ModelMeta, model) + wrapped_model = cast("MTEBModels | ModelMeta", model) else: meta = ModelMeta._from_hub(None) if not isinstance(model, ModelMeta) else model wrapped_model = meta - model_name = cast(str, meta.name) - model_revision = cast(str, meta.revision) + model_name = cast("str", meta.name) + model_revision = cast("str", meta.revision) return wrapped_model, meta, model_name, model_revision @@ -202,10 +204,10 @@ def _check_model_modalities( if isinstance(tasks, AbsTask): check_tasks = [tasks] elif isinstance(tasks, Benchmark): - benchmark = cast(Benchmark, tasks) + benchmark = cast("Benchmark", tasks) check_tasks = benchmark.tasks else: - check_tasks = cast(Iterable[AbsTask], tasks) + check_tasks = cast("Iterable[AbsTask]", tasks) warnings, errors = [], [] @@ -342,7 +344,7 @@ def evaluate( # AbsTaskAggregate is a special case where we have to run multiple tasks and combine the results if isinstance(tasks, AbsTaskAggregate): - aggregated_task = cast(AbsTaskAggregate, tasks) + aggregated_task = cast("AbsTaskAggregate", tasks) results = evaluate( model, aggregated_task.metadata.tasks, @@ -365,7 +367,7 @@ def evaluate( if isinstance(tasks, AbsTask): task = tasks else: - tasks = cast(Iterable[AbsTask], tasks) + tasks = cast("Iterable[AbsTask]", tasks) evaluate_results = [] exceptions = [] tasks_tqdm = tqdm( diff --git a/mteb/filter_tasks.py b/mteb/filter_tasks.py index ea0f5cc0f8..7d3b5c4c22 100644 --- a/mteb/filter_tasks.py +++ b/mteb/filter_tasks.py @@ -1,19 +1,24 @@ """This script contains functions that are used to get an overview of the MTEB benchmark.""" +from __future__ import annotations + import logging -from collections.abc import Iterable, Sequence -from typing import overload +from typing import TYPE_CHECKING, overload -from mteb.abstasks import ( - AbsTask, -) from mteb.abstasks.aggregated_task import AbsTaskAggregate -from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType from mteb.languages import ( ISO_TO_LANGUAGE, ISO_TO_SCRIPT, ) -from mteb.types import Modalities + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from mteb.abstasks import ( + AbsTask, + ) + from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType + from mteb.types import Modalities logger = logging.getLogger(__name__) diff --git a/mteb/get_tasks.py b/mteb/get_tasks.py index 1c4efcf226..ea5ce1fa5e 100644 --- a/mteb/get_tasks.py +++ b/mteb/get_tasks.py @@ -1,20 +1,25 @@ """This script contains functions that are used to get an overview of the MTEB benchmark.""" +from __future__ import annotations + import difflib import logging import warnings from collections import Counter, defaultdict -from collections.abc import Iterable, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import pandas as pd from mteb.abstasks import ( AbsTask, ) -from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType from mteb.filter_tasks import filter_tasks -from mteb.types import Modalities + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from mteb.abstasks.task_metadata import TaskCategory, TaskDomain, TaskType + from mteb.types import Modalities logger = logging.getLogger(__name__) diff --git a/mteb/languages/language_scripts.py b/mteb/languages/language_scripts.py index 3cf48b9aa8..645e3d6d75 100644 --- a/mteb/languages/language_scripts.py +++ b/mteb/languages/language_scripts.py @@ -1,10 +1,15 @@ -from collections.abc import Iterable, Sequence -from dataclasses import dataclass +from __future__ import annotations -from typing_extensions import Self +from dataclasses import dataclass +from typing import TYPE_CHECKING from mteb.languages.check_language_code import check_language_code +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from typing_extensions import Self + @dataclass class LanguageScripts: diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py index 838056bc85..6348601e95 100644 --- a/mteb/leaderboard/app.py +++ b/mteb/leaderboard/app.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import json import logging @@ -5,15 +7,14 @@ import time import warnings from pathlib import Path -from typing import Literal, get_args +from typing import TYPE_CHECKING, Literal, get_args from urllib.parse import urlencode import cachetools import gradio as gr -import pandas as pd +import pandas as pd # noqa: TC002 # gradio tries to validate typehints import mteb -from mteb import BenchmarkResults from mteb.benchmarks.benchmark import RtebBenchmark from mteb.cache import ResultCache from mteb.leaderboard.benchmark_selector import ( @@ -31,6 +32,9 @@ from mteb.leaderboard.text_segments import ACKNOWLEDGEMENT, FAQ from mteb.models.model_meta import MODEL_TYPES +if TYPE_CHECKING: + from mteb import BenchmarkResults + logger = logging.getLogger(__name__) diff --git a/mteb/leaderboard/table.py b/mteb/leaderboard/table.py index bb9663da63..23048c3b7a 100644 --- a/mteb/leaderboard/table.py +++ b/mteb/leaderboard/table.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import gradio as gr import matplotlib.pyplot as plt import numpy as np @@ -5,8 +9,9 @@ from matplotlib.colors import LinearSegmentedColormap from pandas.api.types import is_numeric_dtype -from mteb.benchmarks.benchmark import Benchmark -from mteb.results.benchmark_results import BenchmarkResults +if TYPE_CHECKING: + from mteb.benchmarks.benchmark import Benchmark + from mteb.results.benchmark_results import BenchmarkResults def _borda_count(scores: pd.Series) -> pd.Series: diff --git a/mteb/load_results.py b/mteb/load_results.py index c306423bd5..47f362b4bc 100644 --- a/mteb/load_results.py +++ b/mteb/load_results.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import json import logging import sys -from collections.abc import Iterable, Sequence -from pathlib import Path +from typing import TYPE_CHECKING from mteb.abstasks.abstask import AbsTask from mteb.models.model_meta import ModelMeta from mteb.results import BenchmarkResults, ModelResult, TaskResult -from mteb.types import ModelName, Revision + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from pathlib import Path + + from mteb.types import ModelName, Revision if sys.version_info >= (3, 13): from warnings import deprecated diff --git a/mteb/models/abs_encoder.py b/mteb/models/abs_encoder.py index f0b55ca8ad..52c81c4797 100644 --- a/mteb/models/abs_encoder.py +++ b/mteb/models/abs_encoder.py @@ -1,14 +1,12 @@ +from __future__ import annotations + import logging import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence -from typing import Any, Literal, cast, get_args, overload - -from torch.utils.data import DataLoader -from typing_extensions import Unpack +from typing import TYPE_CHECKING, Any, Literal, cast, get_args, overload import mteb -from mteb.abstasks.task_metadata import TaskMetadata, TaskType +from mteb.abstasks.task_metadata import TaskType from mteb.similarity_functions import ( cos_sim, dot_score, @@ -18,13 +16,25 @@ pairwise_max_sim, ) from mteb.types import ( - Array, - BatchedInput, - EncodeKwargs, PromptType, ) -from .model_meta import ModelMeta, ScoringFunction +from .model_meta import ScoringFunction + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from torch.utils.data import DataLoader + from typing_extensions import Unpack + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import ( + Array, + BatchedInput, + EncodeKwargs, + ) + + from .model_meta import ModelMeta logger = logging.getLogger(__name__) @@ -314,7 +324,7 @@ def similarity(self, embeddings1: Array, embeddings2: Array) -> Array: ): arr = self.model.similarity(embeddings1, embeddings2) # We assume that the model returns an Array-like object: - arr = cast(Array, arr) + arr = cast("Array", arr) return arr return cos_sim(embeddings1, embeddings2) if self.mteb_model_meta.similarity_fn_name is ScoringFunction.COSINE: @@ -352,7 +362,7 @@ def similarity_pairwise( ): arr = self.model.similarity_pairwise(embeddings1, embeddings2) # We assume that the model returns an Array-like object: - arr = cast(Array, arr) + arr = cast("Array", arr) return arr return pairwise_cos_sim(embeddings1, embeddings2) if self.mteb_model_meta.similarity_fn_name is ScoringFunction.COSINE: diff --git a/mteb/models/cache_wrappers/cache_backend_protocol.py b/mteb/models/cache_wrappers/cache_backend_protocol.py index b194b044d8..ef10e79686 100644 --- a/mteb/models/cache_wrappers/cache_backend_protocol.py +++ b/mteb/models/cache_wrappers/cache_backend_protocol.py @@ -1,9 +1,11 @@ from __future__ import annotations -from pathlib import Path -from typing import Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable -import numpy as np +if TYPE_CHECKING: + from pathlib import Path + + import numpy as np @runtime_checkable diff --git a/mteb/models/cache_wrappers/cache_backends/_hash_utils.py b/mteb/models/cache_wrappers/cache_backends/_hash_utils.py index f86cfb5702..80511e6395 100644 --- a/mteb/models/cache_wrappers/cache_backends/_hash_utils.py +++ b/mteb/models/cache_wrappers/cache_backends/_hash_utils.py @@ -1,6 +1,12 @@ +from __future__ import annotations + import hashlib -from collections.abc import Mapping -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Mapping + + from PIL import Image def _hash_item(item: Mapping[str, Any]) -> str: @@ -10,8 +16,6 @@ def _hash_item(item: Mapping[str, Any]) -> str: item_hash = hashlib.sha256(item_text.encode()).hexdigest() if "image" in item: - from PIL import Image - image: Image.Image = item["image"] item_hash += hashlib.sha256(image.tobytes()).hexdigest() diff --git a/mteb/models/cache_wrappers/cache_backends/faiss_cache.py b/mteb/models/cache_wrappers/cache_backends/faiss_cache.py index a5cce688ab..edcdd571c1 100644 --- a/mteb/models/cache_wrappers/cache_backends/faiss_cache.py +++ b/mteb/models/cache_wrappers/cache_backends/faiss_cache.py @@ -1,16 +1,22 @@ +from __future__ import annotations + import json import logging import warnings from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from mteb._requires_package import requires_package -from mteb.types import BatchedInput from ._hash_utils import _hash_item +if TYPE_CHECKING: + import faiss + + from mteb.types import BatchedInput + logger = logging.getLogger(__name__) @@ -24,7 +30,6 @@ def __init__(self, directory: str | Path): "FAISS-based vector cache", install_instruction="pip install mteb[faiss-cpu]", ) - import faiss self.directory = Path(directory) self.directory.mkdir(parents=True, exist_ok=True) diff --git a/mteb/models/cache_wrappers/cache_wrapper.py b/mteb/models/cache_wrappers/cache_wrapper.py index 4807385074..1acd09ea19 100644 --- a/mteb/models/cache_wrappers/cache_wrapper.py +++ b/mteb/models/cache_wrappers/cache_wrapper.py @@ -1,21 +1,26 @@ +from __future__ import annotations + import logging from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch from datasets import Dataset -from torch.utils.data import DataLoader from mteb._create_dataloaders import create_dataloader -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.models.cache_wrappers.cache_backend_protocol import ( - CacheBackendProtocol, -) from mteb.models.cache_wrappers.cache_backends.numpy_cache import NumpyCache -from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import EncoderProtocol -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models.cache_wrappers.cache_backend_protocol import ( + CacheBackendProtocol, + ) + from mteb.models.model_meta import ModelMeta + from mteb.models.models_protocols import EncoderProtocol + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/get_model_meta.py b/mteb/models/get_model_meta.py index 0f6b6dd33c..72e3d6de79 100644 --- a/mteb/models/get_model_meta.py +++ b/mteb/models/get_model_meta.py @@ -1,15 +1,22 @@ +from __future__ import annotations + import difflib import logging -from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any -from mteb.abstasks import AbsTask from mteb.models import ( ModelMeta, - MTEBModels, ) from mteb.models.model_implementations import MODEL_REGISTRY +if TYPE_CHECKING: + from collections.abc import Iterable + + from mteb.abstasks import AbsTask + from mteb.models import ( + MTEBModels, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/models/instruct_wrapper.py b/mteb/models/instruct_wrapper.py index 0dc65e7509..b15c17820c 100644 --- a/mteb/models/instruct_wrapper.py +++ b/mteb/models/instruct_wrapper.py @@ -1,16 +1,24 @@ +from __future__ import annotations + import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType from .abs_encoder import AbsEncoder +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput + + logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/align_models.py b/mteb/models/model_implementations/align_models.py index c173c7681b..0dfdfe500c 100644 --- a/mteb/models/model_implementations/align_models.py +++ b/mteb/models/model_implementations/align_models.py @@ -1,13 +1,18 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType class ALIGNModel(AbsEncoder): diff --git a/mteb/models/model_implementations/bedrock_models.py b/mteb/models/model_implementations/bedrock_models.py index 80ae610d74..fd0022ec7a 100644 --- a/mteb/models/model_implementations/bedrock_models.py +++ b/mteb/models/model_implementations/bedrock_models.py @@ -1,20 +1,30 @@ +from __future__ import annotations + import json import logging import re -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType -from .cohere_models import model_prompts as cohere_model_prompts -from .cohere_models import supported_languages as cohere_supported_languages +from .cohere_models import ( + model_prompts as cohere_model_prompts, +) +from .cohere_models import ( + supported_languages as cohere_supported_languages, +) + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType + logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/blip2_models.py b/mteb/models/model_implementations/blip2_models.py index bb197caa2c..d4e1a91e64 100644 --- a/mteb/models/model_implementations/blip2_models.py +++ b/mteb/models/model_implementations/blip2_models.py @@ -1,14 +1,19 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType BLIP2_CITATION = """@inproceedings{li2023blip2, title={{BLIP-2:} Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models}, diff --git a/mteb/models/model_implementations/blip_models.py b/mteb/models/model_implementations/blip_models.py index 4258b8eb94..b1214e12a7 100644 --- a/mteb/models/model_implementations/blip_models.py +++ b/mteb/models/model_implementations/blip_models.py @@ -1,14 +1,19 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch from torch.nn.functional import normalize -from torch.utils.data import DataLoader from tqdm.auto import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType BLIP_CITATION = """@misc{https://doi.org/10.48550/arxiv.2201.12086, doi = {10.48550/ARXIV.2201.12086}, diff --git a/mteb/models/model_implementations/bm25.py b/mteb/models/model_implementations/bm25.py index c0f93f748b..dec2b6d698 100644 --- a/mteb/models/model_implementations/bm25.py +++ b/mteb/models/model_implementations/bm25.py @@ -1,18 +1,23 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING from mteb._create_dataloaders import _create_text_queries_dataloader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import SearchProtocol -from mteb.types import ( - CorpusDatasetType, - EncodeKwargs, - InstructionDatasetType, - QueryDatasetType, - RetrievalOutputType, - TopRankedDocumentsType, -) + +if TYPE_CHECKING: + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models.models_protocols import SearchProtocol + from mteb.types import ( + CorpusDatasetType, + EncodeKwargs, + InstructionDatasetType, + QueryDatasetType, + RetrievalOutputType, + TopRankedDocumentsType, + ) logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/bmretriever_models.py b/mteb/models/model_implementations/bmretriever_models.py index 8a52b55531..b1cd20f90c 100644 --- a/mteb/models/model_implementations/bmretriever_models.py +++ b/mteb/models/model_implementations/bmretriever_models.py @@ -1,5 +1,6 @@ -from collections.abc import Callable -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch from sentence_transformers import SentenceTransformer @@ -9,6 +10,9 @@ from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.types import PromptType +if TYPE_CHECKING: + from collections.abc import Callable + def instruction_template( instruction: str, prompt_type: PromptType | None = None diff --git a/mteb/models/model_implementations/cde_models.py b/mteb/models/model_implementations/cde_models.py index 7158bc2840..c0119397c6 100644 --- a/mteb/models/model_implementations/cde_models.py +++ b/mteb/models/model_implementations/cde_models.py @@ -1,27 +1,31 @@ +from __future__ import annotations + import logging -from collections.abc import Sequence from typing import TYPE_CHECKING, Any import numpy as np import torch -from torch.utils.data import DataLoader import mteb from mteb._create_dataloaders import _corpus_to_dict -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.models.models_protocols import PromptType from mteb.models.sentence_transformer_wrapper import SentenceTransformerEncoderWrapper -from mteb.types import Array, BatchedInput +from mteb.types import PromptType from .bge_models import bge_full_data if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.utils.data import DataLoader + from mteb.abstasks import ( AbsTaskClassification, AbsTaskRetrieval, AbsTaskSummarization, ) + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) CDE_CITATION = """@misc{morris2024contextualdocumentembeddings, diff --git a/mteb/models/model_implementations/clip_models.py b/mteb/models/model_implementations/clip_models.py index 65e07cce4b..03f3af9a1b 100644 --- a/mteb/models/model_implementations/clip_models.py +++ b/mteb/models/model_implementations/clip_models.py @@ -1,13 +1,18 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType class CLIPModel(AbsEncoder): diff --git a/mteb/models/model_implementations/cohere_models.py b/mteb/models/model_implementations/cohere_models.py index a110e7d6ba..5822fb3c03 100644 --- a/mteb/models/model_implementations/cohere_models.py +++ b/mteb/models/model_implementations/cohere_models.py @@ -1,18 +1,24 @@ +from __future__ import annotations + import logging import time from functools import wraps -from typing import Any, Literal, get_args +from typing import TYPE_CHECKING, Any, Literal, get_args import numpy as np import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/cohere_v.py b/mteb/models/model_implementations/cohere_v.py index bbf06d9b4d..b6c89bf754 100644 --- a/mteb/models/model_implementations/cohere_v.py +++ b/mteb/models/model_implementations/cohere_v.py @@ -1,15 +1,15 @@ +from __future__ import annotations + import base64 import io import os import time -from typing import Any, Literal, get_args +from typing import TYPE_CHECKING, Any, Literal, get_args import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies, requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models import ModelMeta from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_implementations.cohere_models import ( @@ -18,7 +18,12 @@ retry_with_rate_limit, ) from mteb.models.model_meta import ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType def _post_process_embeddings( diff --git a/mteb/models/model_implementations/colpali_models.py b/mteb/models/model_implementations/colpali_models.py index 46324cb890..329615eb9f 100644 --- a/mteb/models/model_implementations/colpali_models.py +++ b/mteb/models/model_implementations/colpali_models.py @@ -4,20 +4,21 @@ from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import ( requires_image_dependencies, requires_package, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/colqwen_models.py b/mteb/models/model_implementations/colqwen_models.py index e285a80c6b..a41b46a8eb 100644 --- a/mteb/models/model_implementations/colqwen_models.py +++ b/mteb/models/model_implementations/colqwen_models.py @@ -1,18 +1,23 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import ( requires_image_dependencies, requires_package, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType from .colpali_models import ( COLPALI_CITATION, diff --git a/mteb/models/model_implementations/conan_models.py b/mteb/models/model_implementations/conan_models.py index b2f175f827..fc70397ea6 100644 --- a/mteb/models/model_implementations/conan_models.py +++ b/mteb/models/model_implementations/conan_models.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import json import logging @@ -5,20 +7,24 @@ import random import string import time -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import requests -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType from .bge_models import bge_full_data from .e5_instruct import E5_MISTRAL_TRAINING_DATA +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType + + conan_zh_datasets = { "BQ", "LCQMC", diff --git a/mteb/models/model_implementations/dino_models.py b/mteb/models/model_implementations/dino_models.py index b979252c79..00c9092fee 100644 --- a/mteb/models/model_implementations/dino_models.py +++ b/mteb/models/model_implementations/dino_models.py @@ -1,13 +1,18 @@ -from typing import Any, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType class DINOModel(AbsEncoder): diff --git a/mteb/models/model_implementations/e5_v.py b/mteb/models/model_implementations/e5_v.py index 9833f9009c..a09e9a5ba1 100644 --- a/mteb/models/model_implementations/e5_v.py +++ b/mteb/models/model_implementations/e5_v.py @@ -1,14 +1,19 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch from packaging import version -from torch.utils.data import DataLoader from tqdm.auto import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType E5_V_TRANSFORMERS_VERSION = ( "4.44.2" # Issue 1647: Only works with transformers==4.44.2. diff --git a/mteb/models/model_implementations/eagerworks_models.py b/mteb/models/model_implementations/eagerworks_models.py index a8603a5f58..ae6470d471 100644 --- a/mteb/models/model_implementations/eagerworks_models.py +++ b/mteb/models/model_implementations/eagerworks_models.py @@ -1,17 +1,23 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import ( requires_image_dependencies, requires_package, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput class EagerEmbedV1Wrapper(AbsEncoder): diff --git a/mteb/models/model_implementations/evaclip_models.py b/mteb/models/model_implementations/evaclip_models.py index ab61db6e96..0cc04c0951 100644 --- a/mteb/models/model_implementations/evaclip_models.py +++ b/mteb/models/model_implementations/evaclip_models.py @@ -1,15 +1,20 @@ +from __future__ import annotations + from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType EVA_CLIP_CITATION = """@article{EVA-CLIP, title={EVA-CLIP: Improved Training Techniques for CLIP at Scale}, diff --git a/mteb/models/model_implementations/gme_v_models.py b/mteb/models/model_implementations/gme_v_models.py index eea5aeae2b..d54c005dc3 100644 --- a/mteb/models/model_implementations/gme_v_models.py +++ b/mteb/models/model_implementations/gme_v_models.py @@ -6,16 +6,18 @@ from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.autonotebook import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/google_models.py b/mteb/models/model_implementations/google_models.py index 1d88fa44b1..ee98344ab1 100644 --- a/mteb/models/model_implementations/google_models.py +++ b/mteb/models/model_implementations/google_models.py @@ -1,17 +1,23 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import numpy as np from packaging.version import Version -from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import __version__ as transformers_version from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models import sentence_transformers_loader from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput MULTILINGUAL_EVALUATED_LANGUAGES = [ "arb-Arab", diff --git a/mteb/models/model_implementations/granite_vision_embedding_models.py b/mteb/models/model_implementations/granite_vision_embedding_models.py index d7a2d810b8..facb0ac3d7 100644 --- a/mteb/models/model_implementations/granite_vision_embedding_models.py +++ b/mteb/models/model_implementations/granite_vision_embedding_models.py @@ -4,20 +4,21 @@ from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import ( requires_image_dependencies, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType - -logger = logging.getLogger(__name__) if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType + +logger = logging.getLogger(__name__) class GraniteVisionEmbeddingWrapper: diff --git a/mteb/models/model_implementations/hinvec_models.py b/mteb/models/model_implementations/hinvec_models.py index 70b32accd4..f891495a01 100644 --- a/mteb/models/model_implementations/hinvec_models.py +++ b/mteb/models/model_implementations/hinvec_models.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING from mteb.models.model_meta import ModelMeta from mteb.models.sentence_transformer_wrapper import sentence_transformers_loader -from mteb.types import PromptType +if TYPE_CHECKING: + from mteb.types import PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/jasper_models.py b/mteb/models/model_implementations/jasper_models.py index b3f3651edc..03dd820d29 100644 --- a/mteb/models/model_implementations/jasper_models.py +++ b/mteb/models/model_implementations/jasper_models.py @@ -1,11 +1,10 @@ +from __future__ import annotations + import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_implementations.bge_models import ( @@ -17,7 +16,15 @@ from mteb.models.model_implementations.nvidia_models import nvidia_training_datasets from mteb.models.model_implementations.qzhou_models import qzhou_training_data from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/jina_clip.py b/mteb/models/model_implementations/jina_clip.py index edc1eb5c97..4c86ebddbc 100644 --- a/mteb/models/model_implementations/jina_clip.py +++ b/mteb/models/model_implementations/jina_clip.py @@ -1,15 +1,20 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_implementations.colpali_models import COLPALI_TRAINING_DATA from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType JINA_CLIP_CITATION = """@article{koukounas2024jinaclip, title={Jina CLIP: Your CLIP Model Is Also Your Text Retriever}, diff --git a/mteb/models/model_implementations/jina_models.py b/mteb/models/model_implementations/jina_models.py index eacb8d943d..b016d1691d 100644 --- a/mteb/models/model_implementations/jina_models.py +++ b/mteb/models/model_implementations/jina_models.py @@ -1,14 +1,13 @@ +from __future__ import annotations + import logging from collections import defaultdict -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import numpy as np import torch -from sentence_transformers import CrossEncoder -from torch.utils.data import DataLoader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.languages import PROGRAMMING_LANGS from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction @@ -16,7 +15,13 @@ CrossEncoderWrapper, SentenceTransformerEncoderWrapper, ) -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from sentence_transformers import CrossEncoder + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/kalm_models.py b/mteb/models/model_implementations/kalm_models.py index 3edcec2083..1613d53efc 100644 --- a/mteb/models/model_implementations/kalm_models.py +++ b/mteb/models/model_implementations/kalm_models.py @@ -1,14 +1,20 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta from mteb.models.sentence_transformer_wrapper import sentence_transformers_loader -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) @@ -907,23 +913,23 @@ def encode( adapted_from="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v2", superseded_by=None, citation="""@misc{zhao2025kalmembeddingv2, - title={KaLM-Embedding-V2: Superior Training Techniques and Data Inspire A Versatile Embedding Model}, + title={KaLM-Embedding-V2: Superior Training Techniques and Data Inspire A Versatile Embedding Model}, author={Xinping Zhao and Xinshuo Hu and Zifei Shan and Shouzheng Huang and Yao Zhou and Xin Zhang and Zetian Sun and Zhenyu Liu and Dongfang Li and Xinyuan Wei and Youcheng Pan and Yang Xiang and Meishan Zhang and Haofen Wang and Jun Yu and Baotian Hu and Min Zhang}, year={2025}, eprint={2506.20923}, archivePrefix={arXiv}, primaryClass={cs.CL}, - url={https://arxiv.org/abs/2506.20923}, + url={https://arxiv.org/abs/2506.20923}, } @misc{hu2025kalmembedding, - title={KaLM-Embedding: Superior Training Data Brings A Stronger Embedding Model}, + title={KaLM-Embedding: Superior Training Data Brings A Stronger Embedding Model}, author={Xinshuo Hu and Zifei Shan and Xinping Zhao and Zetian Sun and Zhenyu Liu and Dongfang Li and Shaolin Ye and Xinyuan Wei and Qian Chen and Baotian Hu and Haofen Wang and Jun Yu and Min Zhang}, year={2025}, eprint={2501.01028}, archivePrefix={arXiv}, primaryClass={cs.CL}, - url={https://arxiv.org/abs/2501.01028}, + url={https://arxiv.org/abs/2501.01028}, }""", ) @@ -954,22 +960,22 @@ def encode( public_training_data=None, training_datasets=KaLM_Embedding_gemma_3_12b_training_data, citation="""@misc{zhao2025kalmembeddingv2, - title={KaLM-Embedding-V2: Superior Training Techniques and Data Inspire A Versatile Embedding Model}, + title={KaLM-Embedding-V2: Superior Training Techniques and Data Inspire A Versatile Embedding Model}, author={Xinping Zhao and Xinshuo Hu and Zifei Shan and Shouzheng Huang and Yao Zhou and Xin Zhang and Zetian Sun and Zhenyu Liu and Dongfang Li and Xinyuan Wei and Youcheng Pan and Yang Xiang and Meishan Zhang and Haofen Wang and Jun Yu and Baotian Hu and Min Zhang}, year={2025}, eprint={2506.20923}, archivePrefix={arXiv}, primaryClass={cs.CL}, - url={https://arxiv.org/abs/2506.20923}, + url={https://arxiv.org/abs/2506.20923}, } @misc{hu2025kalmembedding, - title={KaLM-Embedding: Superior Training Data Brings A Stronger Embedding Model}, + title={KaLM-Embedding: Superior Training Data Brings A Stronger Embedding Model}, author={Xinshuo Hu and Zifei Shan and Xinping Zhao and Zetian Sun and Zhenyu Liu and Dongfang Li and Shaolin Ye and Xinyuan Wei and Qian Chen and Baotian Hu and Haofen Wang and Jun Yu and Min Zhang}, year={2025}, eprint={2501.01028}, archivePrefix={arXiv}, primaryClass={cs.CL}, - url={https://arxiv.org/abs/2501.01028}, + url={https://arxiv.org/abs/2501.01028}, }""", ) diff --git a/mteb/models/model_implementations/linq_models.py b/mteb/models/model_implementations/linq_models.py index 2108296319..237eba5ee5 100644 --- a/mteb/models/model_implementations/linq_models.py +++ b/mteb/models/model_implementations/linq_models.py @@ -1,11 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import torch from mteb.models.instruct_wrapper import instruct_wrapper from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import PromptType from .e5_instruct import E5_MISTRAL_TRAINING_DATA +if TYPE_CHECKING: + from mteb.types import PromptType LINQ_EMBED_MISTRAL_CITATION = """@misc{LinqAIResearch2024, title={Linq-Embed-Mistral:Elevating Text Retrieval with Improved GPT Data Through Task-Specific Control and Quality Refinement}, author={Junseong Kim and Seolhwa Lee and Jihoon Kwon and Sangmo Gu and Yejin Kim and Minkyung Cho and Jy-yong Sohn and Chanyeol Choi}, diff --git a/mteb/models/model_implementations/listconranker.py b/mteb/models/model_implementations/listconranker.py index 8610d1d52e..76df6f2239 100644 --- a/mteb/models/model_implementations/listconranker.py +++ b/mteb/models/model_implementations/listconranker.py @@ -1,14 +1,19 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta -from mteb.types import BatchedInput, PromptType from .rerankers_custom import RerankerWrapper +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import BatchedInput, PromptType + LISTCONRANKER_CITATION = """@article{liu2025listconranker, title={ListConRanker: A Contrastive Text Reranker with Listwise Encoding}, author={Liu, Junlong and Ma, Yue and Zhao, Ruihui and Zheng, Junhao and Ma, Qianli and Kang, Yangyang}, diff --git a/mteb/models/model_implementations/llm2clip_models.py b/mteb/models/model_implementations/llm2clip_models.py index a872ae9bc6..6398b650a5 100644 --- a/mteb/models/model_implementations/llm2clip_models.py +++ b/mteb/models/model_implementations/llm2clip_models.py @@ -1,15 +1,20 @@ +from __future__ import annotations + from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies, requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType LLM2CLIP_CITATION = """@misc{huang2024llm2clippowerfullanguagemodel, title={LLM2CLIP: Powerful Language Model Unlock Richer Visual Representation}, diff --git a/mteb/models/model_implementations/llm2vec_models.py b/mteb/models/model_implementations/llm2vec_models.py index bb52db9568..4699016597 100644 --- a/mteb/models/model_implementations/llm2vec_models.py +++ b/mteb/models/model_implementations/llm2vec_models.py @@ -1,16 +1,22 @@ +from __future__ import annotations + import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from mteb._requires_package import requires_package, suggest_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.models.models_protocols import EncoderProtocol -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models.models_protocols import EncoderProtocol + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/mcinext_models.py b/mteb/models/model_implementations/mcinext_models.py index 1adc2a606c..9c3ebc7f78 100644 --- a/mteb/models/model_implementations/mcinext_models.py +++ b/mteb/models/model_implementations/mcinext_models.py @@ -1,16 +1,19 @@ +from __future__ import annotations + import logging import os import time import warnings -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import requests from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta -from mteb.types import PromptType +if TYPE_CHECKING: + from mteb.types import PromptType logger = logging.getLogger(__name__) HAKIM_CITATION = """@article{sarmadi2025hakim, diff --git a/mteb/models/model_implementations/moco_models.py b/mteb/models/model_implementations/moco_models.py index 2d1ab1a1fe..71e0144a2a 100644 --- a/mteb/models/model_implementations/moco_models.py +++ b/mteb/models/model_implementations/moco_models.py @@ -1,14 +1,19 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies, requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType MOCOV3_CITATION = """@Article{chen2021mocov3, author = {Xinlei Chen* and Saining Xie* and Kaiming He}, diff --git a/mteb/models/model_implementations/mod_models.py b/mteb/models/model_implementations/mod_models.py index 64e16b90fb..bd04cb7687 100644 --- a/mteb/models/model_implementations/mod_models.py +++ b/mteb/models/model_implementations/mod_models.py @@ -1,6 +1,6 @@ from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import PromptType +from mteb.types import PromptType def instruction_template( diff --git a/mteb/models/model_implementations/model2vec_models.py b/mteb/models/model_implementations/model2vec_models.py index e502194c60..1b68b9a27a 100644 --- a/mteb/models/model_implementations/model2vec_models.py +++ b/mteb/models/model_implementations/model2vec_models.py @@ -1,17 +1,23 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -from torch.utils.data import DataLoader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType from .bge_models import bge_training_data +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType + + logger = logging.getLogger(__name__) MODEL2VEC_CITATION = """@software{minishlab2024model2vec, diff --git a/mteb/models/model_implementations/no_instruct_sentence_models.py b/mteb/models/model_implementations/no_instruct_sentence_models.py index 1b6e1d49a9..e4c4bdd9a7 100644 --- a/mteb/models/model_implementations/no_instruct_sentence_models.py +++ b/mteb/models/model_implementations/no_instruct_sentence_models.py @@ -1,15 +1,22 @@ -from collections.abc import Generator +from __future__ import annotations + from itertools import islice -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from collections.abc import Generator + + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput # https://docs.python.org/3/library/itertools.html#itertools.batched diff --git a/mteb/models/model_implementations/nomic_models.py b/mteb/models/model_implementations/nomic_models.py index bd3b1cd65c..7a12fefb88 100644 --- a/mteb/models/model_implementations/nomic_models.py +++ b/mteb/models/model_implementations/nomic_models.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F from packaging.version import Version -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta, ScoringFunction from mteb.models.sentence_transformer_wrapper import SentenceTransformerEncoderWrapper -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/nomic_models_vision.py b/mteb/models/model_implementations/nomic_models_vision.py index 8bf721a1b8..c86e44893b 100644 --- a/mteb/models/model_implementations/nomic_models_vision.py +++ b/mteb/models/model_implementations/nomic_models_vision.py @@ -4,17 +4,18 @@ import torch import torch.nn.functional as F -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType NOMIC_EMBED_VISION_CITATION = """@article{nussbaum2024nomicembedvision, title={Nomic Embed Vision: Expanding the Latent Space}, diff --git a/mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py b/mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py index 532d34a58a..a730bdd0fc 100644 --- a/mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +++ b/mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py @@ -1,14 +1,18 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch from packaging.version import Version from torch.utils.data import DataLoader from transformers import __version__ as transformers_version -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType LLAMA_NEMORETRIEVER_CITATION = """@misc{xu2025llamanemoretrievercolembedtopperforming, title={Llama Nemoretriever Colembed: Top-Performing Text-Image Retrieval Model}, diff --git a/mteb/models/model_implementations/nvidia_models.py b/mteb/models/model_implementations/nvidia_models.py index 05c567fb41..21208bcc1b 100644 --- a/mteb/models/model_implementations/nvidia_models.py +++ b/mteb/models/model_implementations/nvidia_models.py @@ -1,11 +1,11 @@ +from __future__ import annotations + import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F from packaging.version import Version -from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoModel, AutoTokenizer from transformers import __version__ as transformers_version @@ -16,7 +16,15 @@ from mteb.models.abs_encoder import AbsEncoder from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + + from mteb import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/octen_models.py b/mteb/models/model_implementations/octen_models.py index 87ed0c4df2..7c104a9795 100644 --- a/mteb/models/model_implementations/octen_models.py +++ b/mteb/models/model_implementations/octen_models.py @@ -1,6 +1,6 @@ from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import PromptType +from mteb.types import PromptType def instruction_template( diff --git a/mteb/models/model_implementations/openai_models.py b/mteb/models/model_implementations/openai_models.py index fc10683921..aa8d1ec29b 100644 --- a/mteb/models/model_implementations/openai_models.py +++ b/mteb/models/model_implementations/openai_models.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import logging -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import numpy as np -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/openclip_models.py b/mteb/models/model_implementations/openclip_models.py index f49c4865b0..0211d97339 100644 --- a/mteb/models/model_implementations/openclip_models.py +++ b/mteb/models/model_implementations/openclip_models.py @@ -1,14 +1,19 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies, requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType OPENCLIP_CITATION = """@inproceedings{cherti2023reproducible, title={Reproducible scaling laws for contrastive language-image learning}, diff --git a/mteb/models/model_implementations/opensearch_neural_sparse_models.py b/mteb/models/model_implementations/opensearch_neural_sparse_models.py index fdc13414bc..891835e374 100644 --- a/mteb/models/model_implementations/opensearch_neural_sparse_models.py +++ b/mteb/models/model_implementations/opensearch_neural_sparse_models.py @@ -1,12 +1,18 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput v2_training_data = { "MSMARCO", diff --git a/mteb/models/model_implementations/ops_moa_models.py b/mteb/models/model_implementations/ops_moa_models.py index b9dde34735..5a142ff1a3 100644 --- a/mteb/models/model_implementations/ops_moa_models.py +++ b/mteb/models/model_implementations/ops_moa_models.py @@ -1,8 +1,13 @@ -import numpy as np +from __future__ import annotations + +from typing import TYPE_CHECKING from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta +if TYPE_CHECKING: + from mteb.types import Array + class OPSWrapper(AbsEncoder): def __init__(self, model_name: str, revision: str): @@ -15,7 +20,7 @@ def __init__(self, model_name: str, revision: str): ) self.output_dim = 1536 - def encode(self, sentences: list[str], **kwargs) -> np.ndarray: + def encode(self, sentences: list[str], **kwargs) -> Array: embeddings = self.model.encode(sentences, **kwargs) return embeddings[:, : self.output_dim] diff --git a/mteb/models/model_implementations/promptriever_models.py b/mteb/models/model_implementations/promptriever_models.py index 08e9aa7c0a..a04dfed8d7 100644 --- a/mteb/models/model_implementations/promptriever_models.py +++ b/mteb/models/model_implementations/promptriever_models.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.models.models_protocols import EncoderProtocol -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models.models_protocols import EncoderProtocol + from mteb.types import Array, BatchedInput, PromptType from .repllama_models import RepLLaMAModel, model_prompts diff --git a/mteb/models/model_implementations/pylate_models.py b/mteb/models/model_implementations/pylate_models.py index ff1752212e..adf52e37cc 100644 --- a/mteb/models/model_implementations/pylate_models.py +++ b/mteb/models/model_implementations/pylate_models.py @@ -1,30 +1,36 @@ +from __future__ import annotations + import heapq import logging import shutil import tempfile from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from mteb._create_dataloaders import ( create_dataloader, ) from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import ( - Array, - BatchedInput, - CorpusDatasetType, - EncodeKwargs, - PromptType, - QueryDatasetType, - RetrievalOutputType, - TopRankedDocumentsType, -) +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import ( + Array, + BatchedInput, + CorpusDatasetType, + EncodeKwargs, + QueryDatasetType, + RetrievalOutputType, + TopRankedDocumentsType, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/qwen3_models.py b/mteb/models/model_implementations/qwen3_models.py index f9536853e2..94e1f0e712 100644 --- a/mteb/models/model_implementations/qwen3_models.py +++ b/mteb/models/model_implementations/qwen3_models.py @@ -1,6 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import EncoderProtocol, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from mteb.models.models_protocols import EncoderProtocol def instruction_template( diff --git a/mteb/models/model_implementations/random_baseline.py b/mteb/models/model_implementations/random_baseline.py index 92dd754dd2..22d565c787 100644 --- a/mteb/models/model_implementations/random_baseline.py +++ b/mteb/models/model_implementations/random_baseline.py @@ -5,18 +5,19 @@ import numpy as np import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta from mteb.similarity_functions import ( select_pairwise_similarity, select_similarity, ) -from mteb.types._encoder_io import Array, BatchedInput, PromptType if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types._encoder_io import Array, BatchedInput, PromptType def _string_to_vector(text: str | None, size: int) -> np.ndarray: diff --git a/mteb/models/model_implementations/repllama_models.py b/mteb/models/model_implementations/repllama_models.py index 6ca55bb899..485b864670 100644 --- a/mteb/models/model_implementations/repllama_models.py +++ b/mteb/models/model_implementations/repllama_models.py @@ -1,22 +1,29 @@ +from __future__ import annotations + import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch import torch.nn.functional as F -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ( ModelMeta, ScoringFunction, ) -from mteb.models.models_protocols import EncoderProtocol -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.models.models_protocols import EncoderProtocol + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/rerankers_custom.py b/mteb/models/model_implementations/rerankers_custom.py index 5370670605..badf3d0065 100644 --- a/mteb/models/model_implementations/rerankers_custom.py +++ b/mteb/models/model_implementations/rerankers_custom.py @@ -1,16 +1,22 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType from .bge_models import bge_m3_training_data +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType + + logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/rerankers_monot5_based.py b/mteb/models/model_implementations/rerankers_monot5_based.py index 399c2c5ad3..7b01a772c8 100644 --- a/mteb/models/model_implementations/rerankers_monot5_based.py +++ b/mteb/models/model_implementations/rerankers_monot5_based.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType from .rerankers_custom import RerankerWrapper +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType + + logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/salesforce_models.py b/mteb/models/model_implementations/salesforce_models.py index b1a9a1fa67..b33d2b62e2 100644 --- a/mteb/models/model_implementations/salesforce_models.py +++ b/mteb/models/model_implementations/salesforce_models.py @@ -1,12 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from mteb.models.instruct_wrapper import ( InstructSentenceTransformerModel, instruct_wrapper, ) from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import PromptType from .e5_instruct import E5_MISTRAL_TRAINING_DATA +if TYPE_CHECKING: + from mteb.types import PromptType + def instruction_template( instruction: str, prompt_type: PromptType | None = None diff --git a/mteb/models/model_implementations/seed_1_6_embedding_models.py b/mteb/models/model_implementations/seed_1_6_embedding_models.py index 5b56594f4a..238e435529 100644 --- a/mteb/models/model_implementations/seed_1_6_embedding_models.py +++ b/mteb/models/model_implementations/seed_1_6_embedding_models.py @@ -13,16 +13,18 @@ from torch.utils.data import DataLoader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_implementations.bge_models import bge_chinese_training_data from mteb.models.model_implementations.nvidia_models import nvidia_training_datasets from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType if TYPE_CHECKING: from PIL import Image + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput + logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/seed_1_6_embedding_models_1215.py b/mteb/models/model_implementations/seed_1_6_embedding_models_1215.py index 06541e4d87..40a403b527 100644 --- a/mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +++ b/mteb/models/model_implementations/seed_1_6_embedding_models_1215.py @@ -15,15 +15,18 @@ from tqdm import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_implementations.bge_models import bge_chinese_training_data from mteb.models.model_implementations.nvidia_models import nvidia_training_datasets from mteb.models.model_meta import ModelMeta -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/seed_models.py b/mteb/models/model_implementations/seed_models.py index 7d7b1e0125..f120a9c6f4 100644 --- a/mteb/models/model_implementations/seed_models.py +++ b/mteb/models/model_implementations/seed_models.py @@ -9,7 +9,7 @@ from mteb._requires_package import requires_package from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import PromptType +from mteb.types import PromptType from .bge_models import bge_chinese_training_data from .nvidia_models import nvidia_training_datasets diff --git a/mteb/models/model_implementations/siglip_models.py b/mteb/models/model_implementations/siglip_models.py index b8191f485c..f28fc8662f 100644 --- a/mteb/models/model_implementations/siglip_models.py +++ b/mteb/models/model_implementations/siglip_models.py @@ -1,13 +1,18 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType SIGLIP_CITATION = """@misc{zhai2023sigmoid, title={Sigmoid Loss for Language Image Pre-Training}, diff --git a/mteb/models/model_implementations/slm_models.py b/mteb/models/model_implementations/slm_models.py index a0f152c6b2..1b66d38ae7 100644 --- a/mteb/models/model_implementations/slm_models.py +++ b/mteb/models/model_implementations/slm_models.py @@ -13,24 +13,27 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import ( requires_image_dependencies, requires_package, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_implementations.colpali_models import ( COLPALI_CITATION, COLPALI_TRAINING_DATA, ) from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/uae_models.py b/mteb/models/model_implementations/uae_models.py index 3a65788b84..b7f4c862cb 100644 --- a/mteb/models/model_implementations/uae_models.py +++ b/mteb/models/model_implementations/uae_models.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta, ScoringFunction from mteb.models.sentence_transformer_wrapper import SentenceTransformerEncoderWrapper -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/vdr_models.py b/mteb/models/model_implementations/vdr_models.py index 89fec7c4b4..b9d542e542 100644 --- a/mteb/models/model_implementations/vdr_models.py +++ b/mteb/models/model_implementations/vdr_models.py @@ -1,6 +1,12 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import PromptType + +if TYPE_CHECKING: + from mteb.types import PromptType def instruction_template( diff --git a/mteb/models/model_implementations/vista_models.py b/mteb/models/model_implementations/vista_models.py index 4cec94c611..0ee7bdc250 100644 --- a/mteb/models/model_implementations/vista_models.py +++ b/mteb/models/model_implementations/vista_models.py @@ -1,14 +1,19 @@ -from typing import Any, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType VISTA_CITATION = """@article{zhou2024vista, title={VISTA: Visualized Text Embedding For Universal Multi-Modal Retrieval}, diff --git a/mteb/models/model_implementations/vlm2vec_models.py b/mteb/models/model_implementations/vlm2vec_models.py index 5c40e25b74..07fd51d086 100644 --- a/mteb/models/model_implementations/vlm2vec_models.py +++ b/mteb/models/model_implementations/vlm2vec_models.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import ( @@ -10,10 +11,14 @@ requires_package, suggest_package, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, PromptType logger = logging.getLogger(__name__) diff --git a/mteb/models/model_implementations/voyage_models.py b/mteb/models/model_implementations/voyage_models.py index c3a299da0f..d262d8b13c 100644 --- a/mteb/models/model_implementations/voyage_models.py +++ b/mteb/models/model_implementations/voyage_models.py @@ -1,16 +1,22 @@ +from __future__ import annotations + import time from functools import wraps -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput VOYAGE_TRAINING_DATA = set( # Self-reported (message from VoyageAI member) diff --git a/mteb/models/model_implementations/voyage_v.py b/mteb/models/model_implementations/voyage_v.py index 21037dd428..814cc6c671 100644 --- a/mteb/models/model_implementations/voyage_v.py +++ b/mteb/models/model_implementations/voyage_v.py @@ -4,17 +4,19 @@ from typing import TYPE_CHECKING, Any, Literal import torch -from torch.utils.data import DataLoader from tqdm.auto import tqdm from mteb._requires_package import requires_image_dependencies, requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.abs_encoder import AbsEncoder from mteb.models.model_meta import ModelMeta, ScoringFunction -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType if TYPE_CHECKING: from PIL import Image + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput logger = logging.getLogger(__name__) @@ -27,6 +29,8 @@ def _downsample_image( Returns: The downsampled image. """ + from PIL.Image import Resampling + width, height = image.size pixels = width * height @@ -42,15 +46,15 @@ def _downsample_image( logger.info( f"Downsampling image from {width}x{height} to {new_width}x{new_height}" ) - return image.resize(new_size, Image.LANCZOS) + return image.resize(new_size, Resampling.LANCZOS) if width > height: if width > 10000: logger.error("Processing extremely wide images.") - return image.resize((10000, height), Image.LANCZOS) + return image.resize((10000, height), Resampling.LANCZOS) else: if height > 10000: logger.error("Processing extremely high images.") - return image.resize((width, 10000), Image.LANCZOS) + return image.resize((width, 10000), Resampling.LANCZOS) return image diff --git a/mteb/models/model_implementations/yuan_models_en.py b/mteb/models/model_implementations/yuan_models_en.py index 7503491859..40512455df 100644 --- a/mteb/models/model_implementations/yuan_models_en.py +++ b/mteb/models/model_implementations/yuan_models_en.py @@ -1,6 +1,6 @@ from mteb.models.instruct_wrapper import InstructSentenceTransformerModel from mteb.models.model_meta import ModelMeta -from mteb.models.models_protocols import PromptType +from mteb.types import PromptType def instruction_template( diff --git a/mteb/models/model_meta.py b/mteb/models/model_meta.py index 3b1b4d4bcf..750e7e36b0 100644 --- a/mteb/models/model_meta.py +++ b/mteb/models/model_meta.py @@ -3,7 +3,7 @@ import json import logging import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable from dataclasses import field from enum import Enum from functools import partial @@ -11,9 +11,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast from huggingface_hub import ( - GitCommitInfo, ModelCard, - ModelCardData, get_safetensors_metadata, hf_hub_download, list_repo_commits, @@ -30,17 +28,24 @@ ) from pydantic import BaseModel, ConfigDict, field_validator, model_validator from transformers import AutoConfig -from typing_extensions import Self from mteb._helpful_enum import HelpfulStrEnum from mteb.languages import check_language_code -from mteb.models.models_protocols import EncoderProtocol, MTEBModels +from mteb.models.models_protocols import MTEBModels from mteb.types import ISOLanguageScript, Licenses, Modalities, StrDate, StrURL if TYPE_CHECKING: + from collections.abc import Sequence + + from huggingface_hub import ( + GitCommitInfo, + ModelCardData, + ) from sentence_transformers import CrossEncoder, SentenceTransformer + from typing_extensions import Self from mteb.abstasks import AbsTask + from mteb.models.models_protocols import EncoderProtocol logger = logging.getLogger(__name__) @@ -479,7 +484,7 @@ def is_zero_shot_on(self, tasks: Sequence[AbsTask] | Sequence[str]) -> bool | No if isinstance(tasks[0], str): benchmark_datasets = set(tasks) else: - tasks = cast(Sequence["AbsTask"], tasks) + tasks = cast("Sequence[AbsTask]", tasks) benchmark_datasets = set() for task in tasks: benchmark_datasets.add(task.metadata.name) @@ -534,7 +539,7 @@ def zero_shot_percentage( if isinstance(tasks[0], str): benchmark_datasets = set(tasks) else: - tasks = cast(Sequence["AbsTask"], tasks) + tasks = cast("Sequence[AbsTask]", tasks) benchmark_datasets = {task.metadata.name for task in tasks} overlap = training_datasets & benchmark_datasets perc_overlap = 100 * (len(overlap) / len(benchmark_datasets)) diff --git a/mteb/models/models_protocols.py b/mteb/models/models_protocols.py index d273ac72dd..b4eb9c6abe 100644 --- a/mteb/models/models_protocols.py +++ b/mteb/models/models_protocols.py @@ -1,22 +1,23 @@ -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from __future__ import annotations -from torch.utils.data import DataLoader -from typing_extensions import Unpack - -from mteb.abstasks.task_metadata import TaskMetadata -from mteb.types import ( - Array, - BatchedInput, - CorpusDatasetType, - EncodeKwargs, - PromptType, - QueryDatasetType, - RetrievalOutputType, - TopRankedDocumentsType, -) +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: + from torch.utils.data import DataLoader + from typing_extensions import Unpack + + from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta + from mteb.types import ( + Array, + BatchedInput, + CorpusDatasetType, + EncodeKwargs, + PromptType, + QueryDatasetType, + RetrievalOutputType, + TopRankedDocumentsType, + ) @runtime_checkable @@ -72,7 +73,7 @@ def search( ... @property - def mteb_model_meta(self) -> "ModelMeta": + def mteb_model_meta(self) -> ModelMeta: """Metadata of the model""" ... @@ -177,7 +178,7 @@ def similarity_pairwise( ... @property - def mteb_model_meta(self) -> "ModelMeta": + def mteb_model_meta(self) -> ModelMeta: """Metadata of the model""" ... @@ -236,7 +237,7 @@ def predict( ... @property - def mteb_model_meta(self) -> "ModelMeta": + def mteb_model_meta(self) -> ModelMeta: """Metadata of the model""" ... diff --git a/mteb/models/search_encoder_index/search_backend_protocol.py b/mteb/models/search_encoder_index/search_backend_protocol.py index 8fd936677f..fbf3870c7c 100644 --- a/mteb/models/search_encoder_index/search_backend_protocol.py +++ b/mteb/models/search_encoder_index/search_backend_protocol.py @@ -1,7 +1,11 @@ -from collections.abc import Callable -from typing import Protocol +from __future__ import annotations -from mteb.types import Array, TopRankedDocumentsType +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from collections.abc import Callable + + from mteb.types import Array, TopRankedDocumentsType class IndexEncoderSearchProtocol(Protocol): diff --git a/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py b/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py index 6234cfc9d6..8a974838a6 100644 --- a/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +++ b/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py @@ -1,14 +1,23 @@ +from __future__ import annotations + import logging import warnings -from collections.abc import Callable +from typing import TYPE_CHECKING import numpy as np import torch from mteb._requires_package import requires_package from mteb.models.model_meta import ScoringFunction -from mteb.models.models_protocols import EncoderProtocol -from mteb.types import Array, TopRankedDocumentsType + +if TYPE_CHECKING: + from collections.abc import Callable + + import faiss + + from mteb.models.models_protocols import EncoderProtocol + from mteb.types import Array, TopRankedDocumentsType + logger = logging.getLogger(__name__) @@ -33,7 +42,6 @@ def __init__(self, model: EncoderProtocol) -> None: install_instruction="pip install mteb[faiss-cpu]", ) - import faiss from faiss import IndexFlatIP, IndexFlatL2 # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes diff --git a/mteb/models/search_wrappers.py b/mteb/models/search_wrappers.py index 4f8479d81f..2a31a5a96d 100644 --- a/mteb/models/search_wrappers.py +++ b/mteb/models/search_wrappers.py @@ -1,28 +1,35 @@ +from __future__ import annotations + import heapq import logging -from typing import Any +from typing import TYPE_CHECKING, Any import torch from datasets import Dataset -from torch.utils.data import DataLoader from mteb._create_dataloaders import ( create_dataloader, ) -from mteb.abstasks.task_metadata import TaskMetadata from mteb.types import ( - Array, - BatchedInput, - CorpusDatasetType, - EncodeKwargs, PromptType, - QueryDatasetType, - RetrievalOutputType, - TopRankedDocumentsType, ) -from .models_protocols import CrossEncoderProtocol, EncoderProtocol -from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import ( + Array, + BatchedInput, + CorpusDatasetType, + EncodeKwargs, + QueryDatasetType, + RetrievalOutputType, + TopRankedDocumentsType, + ) + + from .models_protocols import CrossEncoderProtocol, EncoderProtocol + from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol logger = logging.getLogger(__name__) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index 356e2b65b1..c449c35013 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -7,19 +7,20 @@ import numpy as np import torch from packaging.version import Version -from torch.utils.data import DataLoader -from typing_extensions import Unpack from mteb._log_once import LogOnce from mteb.models import ModelMeta -from mteb.types import Array, BatchedInput, EncodeKwargs, PromptType +from mteb.types import PromptType from .abs_encoder import AbsEncoder if TYPE_CHECKING: from sentence_transformers import CrossEncoder, SentenceTransformer + from torch.utils.data import DataLoader + from typing_extensions import Unpack from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput, EncodeKwargs logger = logging.getLogger(__name__) diff --git a/mteb/models/vllm_wrapper.py b/mteb/models/vllm_wrapper.py index 8f667b9239..f8a6947334 100644 --- a/mteb/models/vllm_wrapper.py +++ b/mteb/models/vllm_wrapper.py @@ -4,23 +4,25 @@ import gc import logging import os -from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal import numpy as np import torch -from torch.utils.data import DataLoader from mteb._requires_package import requires_package -from mteb.abstasks.task_metadata import TaskMetadata from mteb.models import ModelMeta from mteb.models.abs_encoder import AbsEncoder -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader from vllm.config import PoolerConfig # type: ignore[import-not-found] -else: - PoolerConfig = dict[str, Any] + + from mteb.abstasks.task_metadata import TaskMetadata + from mteb.types import Array, BatchedInput + logger = logging.getLogger(__name__) diff --git a/mteb/results/benchmark_results.py b/mteb/results/benchmark_results.py index 5677173c1b..f315ec326b 100644 --- a/mteb/results/benchmark_results.py +++ b/mteb/results/benchmark_results.py @@ -4,34 +4,39 @@ import json import logging import warnings -from collections.abc import Callable, Iterable, Iterator from pathlib import Path -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import pandas as pd from packaging.version import InvalidVersion, Version from pydantic import BaseModel, ConfigDict -from typing_extensions import Self -from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.task_metadata import ( - TaskDomain, - TaskType, -) from mteb.benchmarks.benchmark import Benchmark from mteb.models import ModelMeta from mteb.models.get_model_meta import get_model_metas -from mteb.types import ( - ISOLanguage, - ISOLanguageScript, - Modalities, - Score, - ScoresDict, - SplitName, -) from .model_result import ModelResult, _aggregate_and_pivot +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator + + from typing_extensions import Self + + from mteb.abstasks.abstask import AbsTask + from mteb.abstasks.task_metadata import ( + TaskDomain, + TaskType, + ) + from mteb.types import ( + ISOLanguage, + ISOLanguageScript, + Modalities, + Score, + ScoresDict, + SplitName, + ) + + logger = logging.getLogger(__name__) @@ -144,7 +149,7 @@ def select_models( raise ValueError("name in ModelMeta is None. It must be a string.") name_rev[name.name] = name.revision else: - name_ = cast(str, name) + name_ = cast("str", name) name_rev[name_] = revision for model_res in self.model_results: diff --git a/mteb/results/model_result.py b/mteb/results/model_result.py index a668b33e9e..16b5f7aa35 100644 --- a/mteb/results/model_result.py +++ b/mteb/results/model_result.py @@ -2,30 +2,36 @@ import logging import warnings -from collections.abc import Callable, Iterable -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np import pandas as pd from pydantic import BaseModel, ConfigDict, Field from typing_extensions import overload -from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.task_metadata import ( - TaskDomain, - TaskType, -) from mteb.types import ( - ISOLanguage, - ISOLanguageScript, Modalities, - Score, - ScoresDict, - SplitName, ) from .task_result import TaskError, TaskResult +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from mteb.abstasks.abstask import AbsTask + from mteb.abstasks.task_metadata import ( + TaskDomain, + TaskType, + ) + from mteb.types import ( + ISOLanguage, + ISOLanguageScript, + Score, + ScoresDict, + SplitName, + ) + + logger = logging.getLogger(__name__) @@ -83,7 +89,7 @@ class ModelResult(BaseModel): model_revision: str | None task_results: list[TaskResult] default_modalities: list[Modalities] = Field( - default_factory=lambda: [cast(Modalities, "text")], alias="modalities" + default_factory=lambda: [cast("Modalities", "text")], alias="modalities" ) model_config = ( ConfigDict( # to free up the name model_* which is otherwise protected @@ -202,8 +208,8 @@ def _get_scores( aggregation = aggregation if aggregation is not None else np.mean else: use_fast = True - aggregation = cast(Callable[[list[Score]], Any], aggregation) - getter = cast(Callable[[ScoresDict], Score], getter) + aggregation = cast("Callable[[list[Score]], Any]", aggregation) + getter = cast("Callable[[ScoresDict], Score]", getter) if format == "wide": scores = {} diff --git a/mteb/results/task_result.py b/mteb/results/task_result.py index 723b924d0c..fcb8d2bef1 100644 --- a/mteb/results/task_result.py +++ b/mteb/results/task_result.py @@ -4,34 +4,40 @@ import logging import warnings from collections import defaultdict -from collections.abc import Callable, Iterable, Mapping from functools import cached_property from importlib.metadata import version -from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from huggingface_hub import EvalResult from packaging.version import Version from pydantic import BaseModel, field_validator -from typing_extensions import Self from mteb import TaskMetadata from mteb._helpful_enum import HelpfulStrEnum from mteb.abstasks import AbsTaskClassification from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.task_metadata import TaskDomain from mteb.languages import LanguageScripts from mteb.models.model_meta import ScoringFunction from mteb.types import ( - HFSubset, - ISOLanguage, - ISOLanguageScript, - Score, ScoresDict, SplitName, ) +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + from pathlib import Path + + from typing_extensions import Self + + from mteb.abstasks.task_metadata import TaskDomain + from mteb.types import ( + HFSubset, + ISOLanguage, + ISOLanguageScript, + Score, + ) + logger = logging.getLogger(__name__) diff --git a/mteb/similarity_functions.py b/mteb/similarity_functions.py index cd5f32abb6..6a3326b4ca 100644 --- a/mteb/similarity_functions.py +++ b/mteb/similarity_functions.py @@ -1,8 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import torch -from mteb.models import EncoderProtocol from mteb.models.model_meta import ScoringFunction -from mteb.types import Array + +if TYPE_CHECKING: + from mteb.models import EncoderProtocol + from mteb.types import Array def _use_torch_compile(): diff --git a/mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py b/mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py index c37b0563fc..bb99c230d9 100644 --- a/mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py +++ b/mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py @@ -1,5 +1,5 @@ -from mteb.abstasks import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.retrieval import ( CQADupstackAndroidRetrieval, CQADupstackEnglishRetrieval, @@ -15,7 +15,7 @@ CQADupstackWordpressRetrieval, ) -task_list_cqa: list[AbsTask] = [ +task_list_cqa = [ CQADupstackAndroidRetrieval(), CQADupstackEnglishRetrieval(), CQADupstackGamingRetrieval(), diff --git a/mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py b/mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py index debaa0cc32..f78d68a6cd 100644 --- a/mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py +++ b/mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py @@ -1,10 +1,10 @@ -from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.sts.multilingual.sts17_multilingual_visual_sts import ( STS17MultilingualVisualSTS, ) -task_list_sts17: list[AbsTask] = [ +task_list_sts17 = [ STS17MultilingualVisualSTS().filter_languages( languages=["eng"], hf_subsets=["en-en"] ) diff --git a/mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py b/mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py index 2b2591c927..97c234eb5e 100644 --- a/mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py +++ b/mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py @@ -1,10 +1,10 @@ -from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.sts.multilingual.sts_benchmark_multilingual_visual_sts import ( STSBenchmarkMultilingualVisualSTS, ) -task_list_stsb: list[AbsTask] = [ +task_list_stsb = [ STSBenchmarkMultilingualVisualSTS().filter_languages( languages=["eng"], hf_subsets=["en"] ) diff --git a/mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py b/mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py index ccb5f6d339..237ab4a345 100644 --- a/mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py +++ b/mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py @@ -1,5 +1,5 @@ -from mteb.abstasks import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.retrieval import ( CQADupstackAndroidRetrievalFa, CQADupstackEnglishRetrievalFa, @@ -15,7 +15,7 @@ CQADupstackWordpressRetrievalFa, ) -task_list_cqa: list[AbsTask] = [ +task_list_cqa = [ CQADupstackAndroidRetrievalFa(), CQADupstackEnglishRetrievalFa(), CQADupstackGamingRetrievalFa(), diff --git a/mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py b/mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py index 2879556e2a..be73c7775d 100644 --- a/mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py +++ b/mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py @@ -1,5 +1,5 @@ -from mteb.abstasks import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.classification import ( SynPerChatbotConvSAAnger, SynPerChatbotConvSAFear, @@ -12,7 +12,7 @@ SynPerChatbotConvSASurprise, ) -task_list_cqa: list[AbsTask] = [ +task_list_cqa = [ SynPerChatbotConvSAAnger(), SynPerChatbotConvSASatisfaction(), SynPerChatbotConvSAFriendship(), diff --git a/mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py b/mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py index bf5b7112e0..09d37ea5c5 100644 --- a/mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py +++ b/mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py @@ -1,10 +1,10 @@ -from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.sts.multilingual.sts17_multilingual_visual_sts import ( STS17MultilingualVisualSTS, ) -task_list_sts17_multi: list[AbsTask] = [ +task_list_sts17_multi = [ STS17MultilingualVisualSTS().filter_languages( languages=["ara", "eng", "spa", "kor"], hf_subsets=[ diff --git a/mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py b/mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py index 2a8aced1cf..516f0e16d8 100644 --- a/mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py +++ b/mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py @@ -1,10 +1,10 @@ -from mteb.abstasks.abstask import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.sts.multilingual.sts_benchmark_multilingual_visual_sts import ( STSBenchmarkMultilingualVisualSTS, ) -task_list_multi: list[AbsTask] = [ +task_list_multi = [ STSBenchmarkMultilingualVisualSTS().filter_languages( languages=[ "deu", diff --git a/mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py b/mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py index 2d3ce4c2ba..2a3570ce6a 100644 --- a/mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py +++ b/mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py @@ -1,5 +1,5 @@ -from mteb.abstasks import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.retrieval import ( CQADupstackAndroidNLRetrieval, CQADupstackEnglishNLRetrieval, @@ -15,7 +15,7 @@ CQADupstackWordpressNLRetrieval, ) -task_list_cqa: list[AbsTask] = [ +task_list_cqa = [ CQADupstackAndroidNLRetrieval(), CQADupstackEnglishNLRetrieval(), CQADupstackGamingNLRetrieval(), diff --git a/mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py b/mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py index a9f05a94e5..50151b8ac2 100644 --- a/mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py +++ b/mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py @@ -1,5 +1,5 @@ -from mteb.abstasks import AbsTask -from mteb.abstasks.aggregated_task import AbsTaskAggregate, AggregateTaskMetadata +from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata +from mteb.abstasks.aggregated_task import AbsTaskAggregate from mteb.tasks.retrieval.pol.cqadupstack_pl_retrieval import ( CQADupstackAndroidRetrievalPL, CQADupstackEnglishRetrievalPL, @@ -15,7 +15,7 @@ CQADupstackWordpressRetrievalPL, ) -task_list_cqa: list[AbsTask] = [ +task_list_cqa = [ CQADupstackAndroidRetrievalPL(), CQADupstackEnglishRetrievalPL(), CQADupstackGamingRetrievalPL(), diff --git a/mteb/tasks/clustering/nob/snl_clustering.py b/mteb/tasks/clustering/nob/snl_clustering.py index 08d84e5a5d..532beb5b43 100644 --- a/mteb/tasks/clustering/nob/snl_clustering.py +++ b/mteb/tasks/clustering/nob/snl_clustering.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import random -from collections.abc import Iterable from itertools import islice -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar import datasets from mteb.abstasks.clustering_legacy import AbsTaskClusteringLegacy from mteb.abstasks.task_metadata import TaskMetadata +if TYPE_CHECKING: + from collections.abc import Iterable + + T = TypeVar("T") diff --git a/mteb/tasks/clustering/nob/vg_clustering.py b/mteb/tasks/clustering/nob/vg_clustering.py index 3a811362a1..b88b290ea1 100644 --- a/mteb/tasks/clustering/nob/vg_clustering.py +++ b/mteb/tasks/clustering/nob/vg_clustering.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import random -from collections.abc import Iterable from itertools import islice -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar import datasets from mteb.abstasks.clustering_legacy import AbsTaskClusteringLegacy from mteb.abstasks.task_metadata import TaskMetadata +if TYPE_CHECKING: + from collections.abc import Iterable + + T = TypeVar("T") diff --git a/mteb/tasks/retrieval/eng/limit_retrieval.py b/mteb/tasks/retrieval/eng/limit_retrieval.py index 1183018bbc..70a3cf9e61 100644 --- a/mteb/tasks/retrieval/eng/limit_retrieval.py +++ b/mteb/tasks/retrieval/eng/limit_retrieval.py @@ -1,8 +1,13 @@ -from collections.abc import Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING from mteb.abstasks.retrieval import AbsTaskRetrieval from mteb.abstasks.task_metadata import TaskMetadata +if TYPE_CHECKING: + from collections.abc import Sequence + _CITATION = """ @misc{weller2025theoreticallimit, archiveprefix = {arXiv}, diff --git a/mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py b/mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py index afe91e99ee..c58b6b718e 100644 --- a/mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +++ b/mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py @@ -30,15 +30,15 @@ def load_ruscibench_data( for lang in langs: lang_corpus = cast( - datasets.Dataset, + "datasets.Dataset", datasets.load_dataset(path, f"corpus-{lang}", revision=revision), )["corpus"] lang_queries = cast( - datasets.Dataset, + "datasets.Dataset", datasets.load_dataset(path, f"queries-{lang}", revision=revision), )["queries"] lang_qrels = cast( - datasets.Dataset, + "datasets.Dataset", datasets.load_dataset(path, f"{lang}", revision=revision), )["test"] corpus[lang] = { diff --git a/mteb/types/_encoder_io.py b/mteb/types/_encoder_io.py index 809c82df3f..72b787c848 100644 --- a/mteb/types/_encoder_io.py +++ b/mteb/types/_encoder_io.py @@ -7,10 +7,10 @@ import numpy as np import torch from datasets import Dataset -from typing_extensions import NotRequired if TYPE_CHECKING: from PIL import Image + from typing_extensions import NotRequired class EncodeKwargs(TypedDict): diff --git a/mteb/types/statistics.py b/mteb/types/statistics.py index 97737c387c..e474b46a17 100644 --- a/mteb/types/statistics.py +++ b/mteb/types/statistics.py @@ -1,6 +1,13 @@ -from typing_extensions import NotRequired, TypedDict +from __future__ import annotations -from mteb.types import HFSubset +from typing import TYPE_CHECKING + +from typing_extensions import TypedDict + +if TYPE_CHECKING: + from typing_extensions import NotRequired + + from mteb.types import HFSubset class SplitDescriptiveStatistics(TypedDict): diff --git a/pyproject.toml b/pyproject.toml index 71b3869061..d5bf059b6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,8 +238,9 @@ select = [ "PTH", # use pathlib "TID", # tidy-imports "D", # pydocstyle - "PGH", # pygrep-hooks Use specific rule codes when ignoring type issues - "LOG", # logging + "PGH", # pygrep-hooks Use specific rule codes when ignoring type issues + "LOG", # logging + "TC", # type-checking ] ignore = [ @@ -252,14 +253,29 @@ ignore = [ "D415", # First line should end with a period "C408", # don't use unecc. collection call, e.g. dict over {} ] +future-annotations = true # For TC rules +typing-modules = ["typing_extensions", "mteb.types", "collections.abc"] + [tool.ruff.lint.per-file-ignores] "scripts/*" = ["ALL"] -"docs/*" = ["D", "DOC"] -"tests/*" = ["RUF012", "D", "DOC"] +"docs/*" = [ + "D", + "DOC", +] +"tests/*" = [ + "RUF012", + "D", + "DOC", + "TC", +] "mteb/tasks/*__init__.py" = ["F403"] # undefined import `from .lang import *` "mteb/tasks/*" = ["D", "DOC"] "mteb/models/model_implementations/*" = ["D", "DOC"] +"mteb/benchmarks/benchmark.py" = [ + # pydantic would try to validate fields in Benchmark class + "TC", +] [tool.ruff.lint.flake8-implicit-str-concat] allow-multiline = false @@ -278,6 +294,13 @@ convention = "google" mypy-init-return = true suppress-none-returning = true +[tool.ruff.lint.flake8-type-checking] +strict = true +runtime-evaluated-base-classes = [ + "pydantic.BaseModel", + "mteb.abstasks.task_metadata.TaskMetadata", # https://github.com/astral-sh/ruff/issues/7866 +] + [tool.semantic_release] branch = "main" version_toml = ["pyproject.toml:project.version"] diff --git a/tests/mock_models.py b/tests/mock_models.py index 53ac737381..71214b124b 100644 --- a/tests/mock_models.py +++ b/tests/mock_models.py @@ -1,7 +1,9 @@ """Mock models to be used for testing""" +from __future__ import annotations + from types import SimpleNamespace -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np import torch @@ -13,7 +15,10 @@ from mteb.abstasks.task_metadata import TaskMetadata from mteb.models.model_meta import ModelMeta from mteb.models.sentence_transformer_wrapper import SentenceTransformerEncoderWrapper -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType + +if TYPE_CHECKING: + from mteb.types import Array, BatchedInput empty_metadata_kwargs = dict( loader=None, diff --git a/tests/test_integrations/test_encode_args_passed.py b/tests/test_integrations/test_encode_args_passed.py index 18c8685771..9d617279e9 100644 --- a/tests/test_integrations/test_encode_args_passed.py +++ b/tests/test_integrations/test_encode_args_passed.py @@ -4,7 +4,7 @@ import logging from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pytest @@ -15,9 +15,12 @@ from mteb.abstasks.task_metadata import TaskMetadata from mteb.models import ModelMeta from mteb.models.abs_encoder import AbsEncoder -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType from tests.task_grid import MOCK_MIEB_TASK_GRID, MOCK_TASK_TEST_GRID +if TYPE_CHECKING: + from mteb.types import Array, BatchedInput + logging.basicConfig(level=logging.INFO) diff --git a/tests/test_models/test_cached_embedding_wrapper.py b/tests/test_models/test_cached_embedding_wrapper.py index 4a82e534fc..f00b76901a 100644 --- a/tests/test_models/test_cached_embedding_wrapper.py +++ b/tests/test_models/test_cached_embedding_wrapper.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import shutil from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pytest @@ -16,9 +18,12 @@ from mteb.models.cache_wrappers.cache_wrapper import CachedEmbeddingWrapper from mteb.models.model_implementations.random_baseline import RandomEncoderBaseline from mteb.models.models_protocols import EncoderProtocol -from mteb.types import Array, BatchedInput, PromptType +from mteb.types import PromptType from tests.mock_tasks import MockMultiChoiceTask, MockRetrievalTask +if TYPE_CHECKING: + from mteb.types import Array, BatchedInput + class DummyModel(RandomEncoderBaseline): call_count = 0 diff --git a/tests/test_result_cache.py b/tests/test_result_cache.py index 27a983e647..39cd4ddc65 100644 --- a/tests/test_result_cache.py +++ b/tests/test_result_cache.py @@ -149,7 +149,7 @@ def test_filter_with_modelmeta(): model_name = model_meta.model_name_as_path() model_revision_1 = model_meta.revision - model_revision_1 = cast(str, model_revision_1) + model_revision_1 = cast("str", model_revision_1) sample_paths = [ base / model_name / model_revision_1 / "task1.json", base / model_name / model_revision_1 / "task2.json", @@ -177,7 +177,7 @@ def test_filter_with_string_models(): model_name = model_meta.model_name_as_path() model_revision_1 = model_meta.revision - model_revision_1 = cast(str, model_revision_1) + model_revision_1 = cast("str", model_revision_1) sample_paths = [ base / model_name / model_revision_1 / "task1.json", base / model_name / model_revision_1 / "task2.json", diff --git a/tests/test_tasks/test_task_quality.py b/tests/test_tasks/test_task_quality.py index 90180e8b28..6d8fcd445f 100644 --- a/tests/test_tasks/test_task_quality.py +++ b/tests/test_tasks/test_task_quality.py @@ -284,7 +284,7 @@ def _split_quality( num_samples = split_stats["num_samples"] text_stats = split_stats.get("text_statistics", None) if text_stats: - text_stats = cast(TextStatistics, text_stats) + text_stats = cast("TextStatistics", text_stats) unique_texts = text_stats["unique_texts"] # Note: there could be cases where a dataset