Moved task types to task interface and deleted types module
x-tabdeveloping committed Jan 23, 2024
1 parent 74fcf43 commit 2f1adf1
Showing 9 changed files with 62 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/seb/cli/table.py
@@ -5,7 +5,7 @@
from rich.table import Table

import seb
-from seb.types import Language
+from seb.interfaces.language import Language


def get_main_score(task: seb.TaskResult, langs: Optional[list[Language]]) -> float:
4 changes: 4 additions & 0 deletions src/seb/interfaces/language.py
@@ -0,0 +1,4 @@
+from typing import Literal
+
+Language = Literal["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"]
+languages_in_seb: list[Language] = ["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"]
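For context, the new `Language` type is a `typing.Literal`, so static checkers reject codes outside the benchmark's set, while `languages_in_seb` exposes the same set as a runtime value. A minimal standalone sketch of how the two relate (the `is_seb_language` helper is hypothetical, not part of the commit):

```python
from typing import Literal, get_args

# Mirrors src/seb/interfaces/language.py as added in this commit.
Language = Literal["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"]
languages_in_seb: list[Language] = ["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"]


def is_seb_language(code: str) -> bool:
    # Hypothetical helper: runtime membership check matching the Literal.
    return code in get_args(Language)


lang: Language = "nb"         # accepted by a type checker
# lang = "en"                 # flagged by mypy/pyright: not a member of Language
print(is_seb_language("en"))  # False at runtime
```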
5 changes: 3 additions & 2 deletions src/seb/interfaces/model.py
@@ -2,9 +2,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable

+from numpy.typing import ArrayLike
from pydantic import BaseModel

-from ..types import ArrayLike
+from seb.interfaces.language import Language

if TYPE_CHECKING:
from .task import Task
@@ -48,7 +49,7 @@ class ModelMeta(BaseModel):
description: Optional[str] = None
huggingface_name: Optional[str] = None
reference: Optional[str] = None
-languages: list[str] = []
+languages: list[Language] = []
open_source: bool = False
embedding_size: Optional[int] = None

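Since `ModelMeta` is a pydantic `BaseModel`, narrowing `languages` from `list[str]` to `list[Language]` means invalid codes are rejected when model metadata is constructed, not discovered later. A minimal sketch of the effect, re-declaring only the fields visible in this hunk rather than importing the real class (field values are illustrative):

```python
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError

Language = Literal["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"]


class ModelMeta(BaseModel):
    # Stand-in with only the fields shown in the diff above.
    description: Optional[str] = None
    huggingface_name: Optional[str] = None
    reference: Optional[str] = None
    languages: list[Language] = []
    open_source: bool = False
    embedding_size: Optional[int] = None


ModelMeta(huggingface_name="intfloat/e5-mistral-7b-instruct", languages=["da", "nb"])

try:
    ModelMeta(languages=["en"])  # "en" is not a member of the Language Literal
except ValidationError as err:
    print(err)
```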
6 changes: 4 additions & 2 deletions src/seb/interfaces/mteb_task.py
@@ -5,9 +5,9 @@
from datasets import DatasetDict, concatenate_datasets
from mteb import AbsTask
from mteb import __version__ as mteb_version
+from numpy.typing import ArrayLike

from ..result_dataclasses import TaskResult
-from ..types import ArrayLike, Language
from .model import Encoder
from .task import DescriptiveDatasetStats, Task

@@ -88,7 +88,9 @@ def evaluate(self, model: Encoder) -> TaskResult:
scores = scores.get(split, scores)
score_is_nested = isinstance(scores[next(iter(scores.keys()))], dict)
if not score_is_nested:
-_scores: dict[str, dict[str, Union[float, str]]] = {lang: scores for lang in self.languages}
+_scores: dict[str, dict[str, Union[float, str]]] = {
+    lang: scores for lang in self.languages
+}
scores = _scores

task_result = TaskResult(
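The `evaluate` hunk above handles MTEB output that is not already keyed by language: a flat `{"metric": value}` dict is replicated under each of the task's languages so every `TaskResult` ends up with the same `{language: {metric: value}}` shape. A standalone sketch of that branch (the helper name is hypothetical; in the commit the logic is inline in `evaluate`):

```python
from typing import Union


def nest_scores_by_language(
    scores: dict, languages: list[str]
) -> dict[str, dict[str, Union[float, str]]]:
    # If the first value is not a dict, MTEB returned flat metrics for the
    # whole task; copy them under every language the task covers.
    score_is_nested = isinstance(scores[next(iter(scores.keys()))], dict)
    if not score_is_nested:
        return {lang: scores for lang in languages}
    return scores


print(nest_scores_by_language({"accuracy": 0.81}, ["da", "nb"]))
# {'da': {'accuracy': 0.81}, 'nb': {'accuracy': 0.81}}
```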
33 changes: 31 additions & 2 deletions src/seb/interfaces/task.py
@@ -1,11 +1,40 @@
-from typing import Protocol, runtime_checkable
+from typing import Literal, Protocol, TypedDict, runtime_checkable

from attr import dataclass

+from seb.interfaces.language import Language
+
from ..result_dataclasses import TaskResult
-from ..types import DescriptiveDatasetStats, Domain, Language, TaskType
from .model import Encoder

+Domain = Literal[
+    "social",
+    "poetry",
+    "wiki",
+    "fiction",
+    "non-fiction",
+    "web",
+    "legal",
+    "news",
+    "academic",
+    "spoken",
+    "reviews",
+    "blog",
+    "medical",
+    "government",
+    "bible",
+]
+
+TaskType = Literal[
+    "Classification", "Retrieval", "STS", "BitextMining", "Clustering", "Speed"
+]
+
+
+class DescriptiveDatasetStats(TypedDict):
+    mean_document_length: float
+    std_document_length: float
+    num_documents: int
+

@runtime_checkable
class Task(Protocol):
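These are the definitions moved in from the deleted `src/seb/types.py`: `Domain` and `TaskType` constrain free-form strings to known values, and `DescriptiveDatasetStats` is a `TypedDict`, so it stays a plain dict at runtime while giving checkers its keys and value types. A small sketch of intended use (variable names and numbers are illustrative only):

```python
from typing import Literal, TypedDict

TaskType = Literal[
    "Classification", "Retrieval", "STS", "BitextMining", "Clustering", "Speed"
]


class DescriptiveDatasetStats(TypedDict):
    mean_document_length: float
    std_document_length: float
    num_documents: int


task_type: TaskType = "STS"         # checked statically against the Literal
stats: DescriptiveDatasetStats = {  # just a dict at runtime
    "mean_document_length": 213.4,
    "std_document_length": 96.1,
    "num_documents": 1024,
}
print(stats["num_documents"])
```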
11 changes: 7 additions & 4 deletions src/seb/registered_models/e5_mistral.py
@@ -4,14 +4,13 @@

import torch
import torch.nn.functional as F
+from numpy.typing import ArrayLike
from torch import Tensor
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from seb import models
from seb.interfaces.model import EmbeddingModel, Encoder, ModelMeta

-from ..types import ArrayLike
-
T = TypeVar("T")


@@ -31,7 +30,9 @@ def __init__(self):
self.load_model()

def load_model(self):
-self.tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
+self.tokenizer = AutoTokenizer.from_pretrained(
+    "intfloat/e5-mistral-7b-instruct"
+)
self.model = AutoModel.from_pretrained("intfloat/e5-mistral-7b-instruct")

def preprocess(self, sentences: Sequence[str]) -> BatchEncoding:
@@ -52,7 +53,9 @@ def preprocess(self, sentences: Sequence[str]) -> BatchEncoding:
[*input_ids, self.tokenizer.eos_token_id]
for input_ids in batch_dict["input_ids"] # type: ignore
]
-batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors="pt")
+batch_dict = self.tokenizer.pad(
+    batch_dict, padding=True, return_attention_mask=True, return_tensors="pt"
+)

return batch_dict

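Aside from the import move, the changes in this file are line-length reflows of `preprocess`, which follows the common e5-mistral pattern: tokenize without padding, append the EOS token id to each sequence, then let `tokenizer.pad` build the attention mask and tensors. A trimmed sketch of that pattern; the model name, the EOS append, and the `tokenizer.pad` call come from the diff, while the kwargs of the first tokenizer call are assumptions:

```python
from typing import Sequence

from transformers import AutoTokenizer, BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")


def preprocess(sentences: Sequence[str]) -> BatchEncoding:
    # Tokenize without padding first (kwargs here are assumed, not from the diff).
    batch_dict = tokenizer(
        list(sentences), padding=False, truncation=True, return_attention_mask=False
    )
    # Append EOS to every sequence, then pad the whole batch to tensors.
    batch_dict["input_ids"] = [
        [*input_ids, tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"]
    ]
    return tokenizer.pad(
        batch_dict, padding=True, return_attention_mask=True, return_tensors="pt"
    )
```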
3 changes: 1 addition & 2 deletions src/seb/registered_models/hf_models.py
@@ -5,14 +5,13 @@
from functools import partial
from typing import Any, Optional

+from numpy.typing import ArrayLike
from sentence_transformers import SentenceTransformer

from seb.interfaces.model import EmbeddingModel, ModelMeta
from seb.interfaces.task import Task
from seb.registries import models

-from ..types import ArrayLike
-

def silence_warnings_from_sentence_transformers():
from sentence_transformers.SentenceTransformer import logger
16 changes: 11 additions & 5 deletions src/seb/result_dataclasses.py
@@ -7,8 +7,8 @@
import numpy as np
from pydantic import BaseModel

+from .interfaces.language import Language
from .interfaces.model import ModelMeta
-from .types import Language


class TaskResult(BaseModel):
@@ -28,7 +28,9 @@ class TaskResult(BaseModel):
task_description: str
task_version: str
time_of_run: datetime
-scores: dict[str, dict[str, Union[float, str]]] # {language: {"metric": value}}.
+scores: dict[
+    Language, dict[str, Union[float, str]]
+] # {language: {"metric": value}}.
main_score: str

def get_main_score(self, lang: Optional[Iterable[str]] = None) -> float:
@@ -51,7 +53,7 @@ def get_main_score(self, lang: Optional[Iterable[str]] = None) -> float:
return sum(main_scores) / len(main_scores)

@property
-def languages(self) -> list[str]:
+def languages(self) -> list[Language]:
"""
Returns the languages of the task.
"""
@@ -153,7 +155,9 @@ def to_disk(self, path: Path) -> None:
Write task results to a path.
"""
if path.is_file():
-raise ValueError("Can't save BenchmarkResults to a file. Path must be a directory.")
+raise ValueError(
+    "Can't save BenchmarkResults to a file. Path must be a directory."
+)
path.mkdir(parents=True, exist_ok=True)
for task_result in self.task_results:
if isinstance(task_result, TaskResult):
@@ -170,7 +174,9 @@ def from_disk(cls, path: Path) -> "BenchmarkResults":
Load task results from a path.
"""
if not path.is_dir():
-raise ValueError("Can't load BenchmarkResults from path: {path}. Path must be a directory.")
+raise ValueError(
+    "Can't load BenchmarkResults from path: {path}. Path must be a directory."
+)
task_results = []
for file in path.glob("*.json"):
if file.stem == "meta":
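With `scores` now keyed by `Language`, a `TaskResult` only accepts per-language score dicts for the benchmark's languages, and `get_main_score` averages the `main_score` metric over the requested subset. A rough standalone sketch of that averaging behaviour, using a reduced stand-in rather than the real `TaskResult` (scores and metric name are illustrative):

```python
from typing import Iterable, Literal, Optional, Union

from pydantic import BaseModel

Language = Literal["da", "nb", "nn", "sv", "da-bornholm", "is", "fo"]


class MiniTaskResult(BaseModel):
    # Reduced stand-in: the real TaskResult also carries task name, version, etc.
    scores: dict[Language, dict[str, Union[float, str]]]  # {language: {"metric": value}}
    main_score: str

    def get_main_score(self, lang: Optional[Iterable[str]] = None) -> float:
        langs = set(lang) if lang is not None else set(self.scores)
        main_scores = [
            float(metrics[self.main_score])
            for code, metrics in self.scores.items()
            if code in langs
        ]
        return sum(main_scores) / len(main_scores)


result = MiniTaskResult(
    scores={"da": {"spearman": 0.75}, "nb": {"spearman": 0.5}},
    main_score="spearman",
)
print(result.get_main_score())        # 0.625
print(result.get_main_score(["da"]))  # 0.75
```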
36 changes: 0 additions & 36 deletions src/seb/types.py

This file was deleted.
