diff --git a/src/seb/cli/table.py b/src/seb/cli/table.py index ddc8fea1..1c5da423 100644 --- a/src/seb/cli/table.py +++ b/src/seb/cli/table.py @@ -5,7 +5,7 @@ from rich.table import Table import seb -from seb.types import Language +from seb.interfaces.language import Language def get_main_score(task: seb.TaskResult, langs: Optional[list[Language]]) -> float: diff --git a/src/seb/interfaces/language.py b/src/seb/interfaces/language.py new file mode 100644 index 00000000..37e1b620 --- /dev/null +++ b/src/seb/interfaces/language.py @@ -0,0 +1,4 @@ +from typing import Literal + +Language = Literal["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"] +languages_in_seb: list[Language] = ["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"] diff --git a/src/seb/interfaces/model.py b/src/seb/interfaces/model.py index 201f9a31..5ad07a28 100644 --- a/src/seb/interfaces/model.py +++ b/src/seb/interfaces/model.py @@ -2,9 +2,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable +from numpy.typing import ArrayLike from pydantic import BaseModel -from ..types import ArrayLike +from seb.interfaces.language import Language if TYPE_CHECKING: from .task import Task @@ -48,7 +49,7 @@ class ModelMeta(BaseModel): description: Optional[str] = None huggingface_name: Optional[str] = None reference: Optional[str] = None - languages: list[str] = [] + languages: list[Language] = [] open_source: bool = False embedding_size: Optional[int] = None diff --git a/src/seb/interfaces/mteb_task.py b/src/seb/interfaces/mteb_task.py index f34ba97e..9d306bb9 100644 --- a/src/seb/interfaces/mteb_task.py +++ b/src/seb/interfaces/mteb_task.py @@ -5,9 +5,9 @@ from datasets import DatasetDict, concatenate_datasets from mteb import AbsTask from mteb import __version__ as mteb_version +from numpy.typing import ArrayLike from ..result_dataclasses import TaskResult -from ..types import ArrayLike, Language from .model import Encoder from .task import DescriptiveDatasetStats, Task @@ -88,7 +88,9 @@ def evaluate(self, model: Encoder) -> TaskResult: scores = scores.get(split, scores) score_is_nested = isinstance(scores[next(iter(scores.keys()))], dict) if not score_is_nested: - _scores: dict[str, dict[str, Union[float, str]]] = {lang: scores for lang in self.languages} + _scores: dict[str, dict[str, Union[float, str]]] = { + lang: scores for lang in self.languages + } scores = _scores task_result = TaskResult( diff --git a/src/seb/interfaces/task.py b/src/seb/interfaces/task.py index dffff59c..5d578060 100644 --- a/src/seb/interfaces/task.py +++ b/src/seb/interfaces/task.py @@ -1,11 +1,40 @@ -from typing import Protocol, runtime_checkable +from typing import Literal, Protocol, TypedDict, runtime_checkable from attr import dataclass +from seb.interfaces.language import Language + from ..result_dataclasses import TaskResult -from ..types import DescriptiveDatasetStats, Domain, Language, TaskType from .model import Encoder +Domain = Literal[ + "social", + "poetry", + "wiki", + "fiction", + "non-fiction", + "web", + "legal", + "news", + "academic", + "spoken", + "reviews", + "blog", + "medical", + "government", + "bible", +] + +TaskType = Literal[ + "Classification", "Retrieval", "STS", "BitextMining", "Clustering", "Speed" +] + + +class DescriptiveDatasetStats(TypedDict): + mean_document_length: float + std_document_length: float + num_documents: int + @runtime_checkable class Task(Protocol): diff --git a/src/seb/registered_models/e5_mistral.py b/src/seb/registered_models/e5_mistral.py index e516c868..4dafe5b7 100644 --- a/src/seb/registered_models/e5_mistral.py +++ b/src/seb/registered_models/e5_mistral.py @@ -4,14 +4,13 @@ import torch import torch.nn.functional as F +from numpy.typing import ArrayLike from torch import Tensor from transformers import AutoModel, AutoTokenizer, BatchEncoding from seb import models from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta -from ..types import ArrayLike - T = TypeVar("T") @@ -31,7 +30,9 @@ def __init__(self): self.load_model() def load_model(self): - self.tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct") + self.tokenizer = AutoTokenizer.from_pretrained( + "intfloat/e5-mistral-7b-instruct" + ) self.model = AutoModel.from_pretrained("intfloat/e5-mistral-7b-instruct") def preprocess(self, sentences: Sequence[str]) -> BatchEncoding: @@ -52,7 +53,9 @@ def preprocess(self, sentences: Sequence[str]) -> BatchEncoding: [*input_ids, self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"] # type: ignore ] - batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors="pt") + batch_dict = self.tokenizer.pad( + batch_dict, padding=True, return_attention_mask=True, return_tensors="pt" + ) return batch_dict diff --git a/src/seb/registered_models/hf_models.py b/src/seb/registered_models/hf_models.py index 77fc9b48..4d7d5838 100644 --- a/src/seb/registered_models/hf_models.py +++ b/src/seb/registered_models/hf_models.py @@ -5,14 +5,13 @@ from functools import partial from typing import Any, Optional +from numpy.typing import ArrayLike from sentence_transformers import SentenceTransformer from seb.interfaces.model import EmbeddingModel, ModelMeta from seb.interfaces.task import Task from seb.registries import models -from ..types import ArrayLike - def silence_warnings_from_sentence_transformers(): from sentence_transformers.SentenceTransformer import logger diff --git a/src/seb/result_dataclasses.py b/src/seb/result_dataclasses.py index 3ab524d4..38fa520f 100644 --- a/src/seb/result_dataclasses.py +++ b/src/seb/result_dataclasses.py @@ -7,8 +7,8 @@ import numpy as np from pydantic import BaseModel +from .interfaces.language import Language from .interfaces.model import ModelMeta -from .types import Language class TaskResult(BaseModel): @@ -28,7 +28,9 @@ class TaskResult(BaseModel): task_description: str task_version: str time_of_run: datetime - scores: dict[str, dict[str, Union[float, str]]] # {language: {"metric": value}}. + scores: dict[ + Language, dict[str, Union[float, str]] + ] # {language: {"metric": value}}. main_score: str def get_main_score(self, lang: Optional[Iterable[str]] = None) -> float: @@ -51,7 +53,7 @@ def get_main_score(self, lang: Optional[Iterable[str]] = None) -> float: return sum(main_scores) / len(main_scores) @property - def languages(self) -> list[str]: + def languages(self) -> list[Language]: """ Returns the languages of the task. """ @@ -153,7 +155,9 @@ def to_disk(self, path: Path) -> None: Write task results to a path. """ if path.is_file(): - raise ValueError("Can't save BenchmarkResults to a file. Path must be a directory.") + raise ValueError( + "Can't save BenchmarkResults to a file. Path must be a directory." + ) path.mkdir(parents=True, exist_ok=True) for task_result in self.task_results: if isinstance(task_result, TaskResult): @@ -170,7 +174,9 @@ def from_disk(cls, path: Path) -> "BenchmarkResults": Load task results from a path. """ if not path.is_dir(): - raise ValueError("Can't load BenchmarkResults from path: {path}. Path must be a directory.") + raise ValueError( + "Can't load BenchmarkResults from path: {path}. Path must be a directory." + ) task_results = [] for file in path.glob("*.json"): if file.stem == "meta": diff --git a/src/seb/types.py b/src/seb/types.py deleted file mode 100644 index a9f6a0f7..00000000 --- a/src/seb/types.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Literal, TypedDict, Union - -from numpy import ndarray -from torch import Tensor - -ArrayLike = Union[ndarray, Tensor] - - -Domain = Literal[ - "social", - "poetry", - "wiki", - "fiction", - "non-fiction", - "web", - "legal", - "news", - "academic", - "spoken", - "reviews", - "blog", - "medical", - "government", - "bible", -] - -Language = Literal["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"] -languages_in_seb: list[Language] = ["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"] - -TaskType = Literal["Classification", "Retrieval", "STS", "BitextMining", "Clustering", "Speed"] - - -class DescriptiveDatasetStats(TypedDict): - mean_document_length: float - std_document_length: float - num_documents: int