diff --git a/docs/contributing/adding_a_model.md b/docs/contributing/adding_a_model.md index 7f14102f65..e5911062b5 100644 --- a/docs/contributing/adding_a_model.md +++ b/docs/contributing/adding_a_model.md @@ -40,6 +40,36 @@ Typically, it only requires that you fill in metadata about the model and add it This works for all [Sentence Transformers](https://sbert.net) compatible models. Once filled out, you can submit your model to `mteb` by submitting a PR. +You can generate it automatically by using: + +=== "General model from hub" + ```python + from mteb.models import ModelMeta + + meta = ModelMeta.from_hub("Qwen/Qwen3-Embedding-0.6B") + print(meta.to_python()) + ``` + +=== "For Sentence transformers model" + ```python + from mteb.models import ModelMeta + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", device="cpu") + meta = ModelMeta.from_sentence_transformer_model(model) + print(meta.to_python()) + ``` + +=== "For CrossEncoder" + ```python + from mteb.models import ModelMeta + from sentence_transformers import CrossEncoder + + model = CrossEncoder("Qwen/Qwen3-Reranker-0.6B", device="cpu") + meta = ModelMeta.from_cross_encoder(model) + print(meta.to_python()) + ``` + ### Calculating the Memory Usage diff --git a/mteb/deprecated_evaluator.py b/mteb/deprecated_evaluator.py index c3f7712cf3..c6c419a9c6 100644 --- a/mteb/deprecated_evaluator.py +++ b/mteb/deprecated_evaluator.py @@ -13,21 +13,11 @@ from time import time from typing import TYPE_CHECKING, Any -from mteb.abstasks.task_metadata import TaskCategory, TaskType -from mteb.models.get_model_meta import ( - _model_meta_from_cross_encoder, - _model_meta_from_sentence_transformers, -) - -if sys.version_info >= (3, 13): - from warnings import deprecated -else: - from typing_extensions import deprecated - import datasets import mteb from mteb.abstasks import AbsTask +from mteb.abstasks.task_metadata import TaskCategory, TaskType from 
mteb.benchmarks import Benchmark from mteb.models import ( CrossEncoderWrapper, @@ -39,6 +29,11 @@ from mteb.results import TaskResult from mteb.types import ScoresDict +if sys.version_info >= (3, 13): + from warnings import deprecated +else: + from typing_extensions import deprecated + if TYPE_CHECKING: from sentence_transformers import CrossEncoder, SentenceTransformer @@ -669,9 +664,9 @@ def _get_model_meta(model: EncoderProtocol) -> ModelMeta: from sentence_transformers import CrossEncoder, SentenceTransformer if isinstance(model, CrossEncoder): - meta = _model_meta_from_cross_encoder(model) + meta = ModelMeta.from_cross_encoder(model) elif isinstance(model, SentenceTransformer): - meta = _model_meta_from_sentence_transformers(model) + meta = ModelMeta.from_sentence_transformer_model(model) else: meta = ModelMeta( loader=None, diff --git a/mteb/evaluate.py b/mteb/evaluate.py index 35dcaa3e74..41e730b88d 100644 --- a/mteb/evaluate.py +++ b/mteb/evaluate.py @@ -2,7 +2,6 @@ import logging from collections.abc import Iterable -from copy import deepcopy from pathlib import Path from time import time from typing import TYPE_CHECKING, Any, cast @@ -53,36 +52,6 @@ class OverwriteStrategy(HelpfulStrEnum): ONLY_CACHE = "only-cache" -_empty_model_meta = ModelMeta( - loader=None, - name=None, - revision=None, - release_date=None, - languages=None, - framework=[], - similarity_fn_name=None, - n_parameters=None, - memory_usage_mb=None, - max_tokens=None, - embed_dim=None, - license=None, - open_weights=None, - public_training_code=None, - public_training_data=None, - use_instructions=None, - training_datasets=None, - modalities=[], -) - - -def _create_empty_model_meta() -> ModelMeta: - logger.warning("Model metadata is missing. 
Using empty metadata.") - meta = deepcopy(_empty_model_meta) - meta.revision = "no_revision_available" - meta.name = "no_model_name_available" - return meta - - def _sanitize_model( model: ModelMeta | MTEBModels | SentenceTransformer | CrossEncoder, ) -> tuple[MTEBModels | ModelMeta, ModelMeta, ModelName, Revision]: @@ -101,9 +70,9 @@ def _sanitize_model( elif hasattr(model, "mteb_model_meta"): meta = model.mteb_model_meta # type: ignore[attr-defined] if not isinstance(meta, ModelMeta): - meta = _create_empty_model_meta() + meta = ModelMeta.from_hub(None) else: - meta = _create_empty_model_meta() if not isinstance(model, ModelMeta) else model + meta = ModelMeta.from_hub(None) if not isinstance(model, ModelMeta) else model model_name = cast(str, meta.name) model_revision = cast(str, meta.revision) diff --git a/mteb/models/get_model_meta.py b/mteb/models/get_model_meta.py index 2034a0478b..37a52b56b4 100644 --- a/mteb/models/get_model_meta.py +++ b/mteb/models/get_model_meta.py @@ -1,26 +1,15 @@ -from __future__ import annotations - import difflib import logging -import warnings from collections.abc import Iterable -from typing import TYPE_CHECKING, Any - -from huggingface_hub import ModelCard -from huggingface_hub.errors import RepositoryNotFoundError +from typing import Any from mteb.abstasks import AbsTask from mteb.models import ( - CrossEncoderWrapper, ModelMeta, MTEBModels, - sentence_transformers_loader, ) from mteb.models.model_implementations import MODEL_REGISTRY -if TYPE_CHECKING: - from sentence_transformers import CrossEncoder, SentenceTransformer - logger = logging.getLogger(__name__) @@ -101,24 +90,9 @@ def get_model( Returns: A model object """ - from sentence_transformers import CrossEncoder, SentenceTransformer - meta = get_model_meta(model_name, revision) model = meta.load_model(**kwargs) - # If revision not available in the modelmeta, try to extract it from sentence-transformers - if hasattr(model, "model") and isinstance(model.model, 
SentenceTransformer): # type: ignore - _meta = _model_meta_from_sentence_transformers(model.model) # type: ignore - if meta.revision is None: - meta.revision = _meta.revision if _meta.revision else meta.revision - if not meta.similarity_fn_name: - meta.similarity_fn_name = _meta.similarity_fn_name - - elif isinstance(model, CrossEncoder): - _meta = _model_meta_from_cross_encoder(model.model) - if meta.revision is None: - meta.revision = _meta.revision if _meta.revision else meta.revision - model.mteb_model_meta = meta # type: ignore return model @@ -148,12 +122,8 @@ def get_model_meta( logger.info( "Model not found in model registry. Attempting to extract metadata by loading the model ({model_name}) using HuggingFace." ) - try: - meta = _model_meta_from_hf_hub(model_name) - meta.revision = revision - return meta - except RepositoryNotFoundError: - pass + meta = ModelMeta.from_hub(model_name, revision) + return meta not_found_msg = f"Model '{model_name}' not found in MTEB registry" not_found_msg += " nor on the Huggingface Hub." if fetch_from_hf else "." @@ -171,96 +141,3 @@ def get_model_meta( suggestion = f" Did you mean: '{close_matches[0]}'?" raise KeyError(not_found_msg + suggestion) - - -def _model_meta_from_hf_hub(model_name: str) -> ModelMeta: - card = ModelCard.load(model_name) - card_data = card.data.to_dict() - frameworks = ["PyTorch"] - loader = None - if card_data.get("library_name", None) == "sentence-transformers": - frameworks.append("Sentence Transformers") - loader = sentence_transformers_loader - else: - msg = ( - "Model library not recognized, defaulting to Sentence Transformers loader." 
- ) - logger.warning(msg) - warnings.warn(msg) - loader = sentence_transformers_loader - - revision = card_data.get("base_model_revision", None) - license = card_data.get("license", None) - meta = ModelMeta( - loader=loader, - name=model_name, - revision=revision, - release_date=ModelMeta.fetch_release_date(model_name), - languages=None, - license=license, - framework=frameworks, # type: ignore - training_datasets=None, - similarity_fn_name=None, - n_parameters=None, - memory_usage_mb=None, - max_tokens=None, - embed_dim=None, - open_weights=True, - public_training_code=None, - public_training_data=None, - use_instructions=None, - ) - return meta - - -def _model_meta_from_cross_encoder(model: CrossEncoder) -> ModelMeta: - model_name = model.model.name_or_path - meta = ModelMeta( - loader=CrossEncoderWrapper, - name=model_name, - revision=model.config._commit_hash, - release_date=ModelMeta.fetch_release_date(model_name), - languages=None, - framework=["Sentence Transformers"], - similarity_fn_name=None, - n_parameters=None, - memory_usage_mb=None, - max_tokens=None, - embed_dim=None, - license=None, - open_weights=True, - public_training_code=None, - public_training_data=None, - use_instructions=None, - training_datasets=None, - ) - return meta - - -def _model_meta_from_sentence_transformers(model: SentenceTransformer) -> ModelMeta: - name: str | None = ( - model.model_card_data.model_name - if model.model_card_data.model_name - else model.model_card_data.base_model - ) - embeddings_dim = model.get_sentence_embedding_dimension() - meta = ModelMeta( - loader=sentence_transformers_loader, - name=name, - revision=model.model_card_data.base_model_revision, - release_date=ModelMeta.fetch_release_date(name) if name else None, - languages=None, - framework=["Sentence Transformers"], - similarity_fn_name=None, - n_parameters=None, - memory_usage_mb=None, - max_tokens=None, - embed_dim=embeddings_dim, - license=None, - open_weights=True, - public_training_code=None, - 
public_training_data=None, - use_instructions=None, - training_datasets=None, - ) - return meta diff --git a/mteb/models/model_meta.py b/mteb/models/model_meta.py index 74d8cb7751..a4a657a843 100644 --- a/mteb/models/model_meta.py +++ b/mteb/models/model_meta.py @@ -1,27 +1,46 @@ +from __future__ import annotations + +import json import logging +import warnings from collections.abc import Callable, Sequence from dataclasses import field from enum import Enum from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast -from huggingface_hub import get_safetensors_metadata, list_repo_commits +from huggingface_hub import ( + GitCommitInfo, + ModelCard, + ModelCardData, + get_safetensors_metadata, + hf_hub_download, + list_repo_commits, + repo_exists, +) from huggingface_hub.errors import ( + EntryNotFoundError, GatedRepoError, NotASafetensorsRepoError, RepositoryNotFoundError, SafetensorsParsingError, ) from pydantic import BaseModel, ConfigDict, field_validator +from transformers import AutoConfig +from typing_extensions import Self +from mteb._helpful_enum import HelpfulStrEnum from mteb.languages import check_language_code +from mteb.models.models_protocols import EncoderProtocol, MTEBModels from mteb.types import ISOLanguageScript, Licenses, Modalities, StrDate, StrURL -from .models_protocols import EncoderProtocol, MTEBModels - if TYPE_CHECKING: + from sentence_transformers import CrossEncoder, SentenceTransformer + from mteb.abstasks import AbsTask + logger = logging.getLogger(__name__) FRAMEWORKS = Literal[ @@ -39,7 +58,7 @@ ] -class ScoringFunction(str, Enum): +class ScoringFunction(HelpfulStrEnum): """The scoring function used by the models.""" COSINE = "cosine" @@ -60,6 +79,9 @@ def _get_loader_name( return loader.__name__ +_SENTENCE_TRANSFORMER_LIB_NAME = "Sentence Transformers" + + class ModelMeta(BaseModel): """The model metadata object. 
@@ -214,9 +236,198 @@ def model_name_as_path(self) -> str: raise ValueError("Model name is not set") return self.name.replace("/", "__").replace(" ", "_") - def is_zero_shot_on( - self, tasks: Sequence["AbsTask"] | Sequence[str] - ) -> bool | None: + @classmethod + def _from_hub( + cls, + model_name: str | None, + revision: str | None = None, + compute_metadata: bool = True, + ) -> Self: + """Generates a ModelMeta from a HuggingFace model name. + + Args: + model_name: The HuggingFace model name. + revision: Revision of the model + compute_metadata: Add metadata based on model card + + Returns: + The generated ModelMeta. + """ + from mteb.models import sentence_transformers_loader + + loader = sentence_transformers_loader + frameworks: list[FRAMEWORKS] = ["PyTorch"] + model_license = None + reference = None + n_parameters = None + memory_usage_mb = None + release_date = None + embedding_dim = None + max_tokens = None + + if model_name and compute_metadata and repo_exists(model_name): + reference = "https://huggingface.co/" + model_name + card = ModelCard.load(model_name) + card_data: ModelCardData = card.data + try: + model_config = AutoConfig.from_pretrained(model_name) + except Exception as e: + # some models can't load AutoConfig (e.g. `average_word_embeddings_levy_dependency`) + model_config = None + logger.warning(f"Can't get configuration for {model_name}. Error: {e}") + + # NOTE(review): the hub reports library_name/tags in lowercase + # ("sentence-transformers"), never "Sentence Transformers"; tags may be None. + if ( + card_data.library_name == "sentence-transformers" + or "sentence-transformers" in (card_data.tags or []) + ): + frameworks.append(_SENTENCE_TRANSFORMER_LIB_NAME) + else: + msg = "Model library not recognized, defaulting to Sentence Transformers loader." 
+ logger.warning(msg) + warnings.warn(msg) + + if revision is None: + revisions = _get_repo_commits(model_name, "model") + revision = revisions[0].commit_id if revisions else None + + release_date = cls.fetch_release_date(model_name) + model_license = card_data.license + n_parameters = cls._calculate_num_parameters_from_hub(model_name) + memory_usage_mb = cls._calculate_memory_usage_mb(model_name, n_parameters) + if model_config and hasattr(model_config, "hidden_size"): + embedding_dim = model_config.hidden_size + if model_config and hasattr(model_config, "max_position_embeddings"): + max_tokens = model_config.max_position_embeddings + + return cls( + loader=loader, + name=model_name or "no_model_name/available", + revision=revision or "no_revision_available", + reference=reference, + release_date=release_date, + languages=None, + license=model_license, + framework=frameworks, + training_datasets=None, + similarity_fn_name=None, + n_parameters=n_parameters, + memory_usage_mb=memory_usage_mb, + max_tokens=max_tokens, + embed_dim=embedding_dim, + open_weights=True, + public_training_code=None, + public_training_data=None, + use_instructions=None, + modalities=[], + ) + + @classmethod + def from_sentence_transformer_model( + cls, + model: SentenceTransformer, + revision: str | None = None, + compute_metadata: bool = True, + ) -> Self: + """Generates a ModelMeta from a SentenceTransformer model. + + Args: + model: SentenceTransformer model. + revision: Revision of the model + compute_metadata: Add metadata based on model card + + Returns: + The generated ModelMeta. 
+ """ + name: str | None = ( + model.model_card_data.model_name + if model.model_card_data.model_name + else model.model_card_data.base_model + ) + meta = cls._from_hub(name, revision, compute_metadata) + if _SENTENCE_TRANSFORMER_LIB_NAME not in meta.framework: + meta.framework.append("Sentence Transformers") + meta.revision = model.model_card_data.base_model_revision or meta.revision + meta.max_tokens = model.max_seq_length + meta.embed_dim = model.get_sentence_embedding_dimension() + meta.similarity_fn_name = ScoringFunction.from_str(model.similarity_fn_name) + meta.modalities = ["text"] + return meta + + @classmethod + def from_hub( + cls, + model: str, + revision: str | None = None, + compute_metadata: bool = True, + ) -> Self: + """Generates a ModelMeta for model from HuggingFace hub. + + Args: + model: Name of the model from HuggingFace hub. For example, `intfloat/multilingual-e5-large` + revision: Revision of the model + compute_metadata: Add metadata based on model card + + Returns: + The generated ModelMeta. 
+ """ + meta = cls._from_hub(model, revision, compute_metadata) + if _SENTENCE_TRANSFORMER_LIB_NAME not in meta.framework: + meta.framework.append("Sentence Transformers") + meta.modalities = ["text"] + + if model and compute_metadata and repo_exists(model): + # have max_seq_length field + sbert_config = _get_json_from_hub( + model, "sentence_bert_config.json", "model", revision=revision + ) + if sbert_config: + meta.max_tokens = ( + sbert_config.get("max_seq_length", None) or meta.max_tokens + ) + # have model type, similarity function fields + config_sbert = _get_json_from_hub( + model, "config_sentence_transformers.json", "model", revision=revision + ) + if ( + config_sbert is not None + and config_sbert.get("similarity_fn_name") is not None + ): + meta.similarity_fn_name = ScoringFunction.from_str( + config_sbert.get("similarity_fn_name") + ) + else: + meta.similarity_fn_name = ScoringFunction.COSINE + return meta + + @classmethod + def from_cross_encoder( + cls, + model: CrossEncoder, + revision: str | None = None, + compute_metadata: bool = True, + ) -> Self: + """Generates a ModelMeta from a CrossEncoder. + + Args: + model: The CrossEncoder model + revision: Revision of the model + compute_metadata: Add metadata based on model card + + Returns: + The generated ModelMeta + """ + from mteb.models import CrossEncoderWrapper + + meta = cls._from_hub(model.model.name_or_path, revision, compute_metadata) + if _SENTENCE_TRANSFORMER_LIB_NAME not in meta.framework: + meta.framework.append("Sentence Transformers") + meta.revision = model.config._commit_hash or meta.revision + meta.loader = CrossEncoderWrapper + meta.embed_dim = None + meta.modalities = ["text"] + return meta + + def is_zero_shot_on(self, tasks: Sequence[AbsTask] | Sequence[str]) -> bool | None: """Indicates whether the given model can be considered zero-shot or not on the given tasks. 
Returns: @@ -269,7 +480,7 @@ def get_training_datasets(self) -> set[str] | None: return return_dataset def zero_shot_percentage( - self, tasks: Sequence["AbsTask"] | Sequence[str] + self, tasks: Sequence[AbsTask] | Sequence[str] ) -> int | None: """Indicates how out-of-domain the selected tasks are for the given model. @@ -292,18 +503,38 @@ def zero_shot_percentage( perc_overlap = 100 * (len(overlap) / len(benchmark_datasets)) return int(100 - perc_overlap) - def calculate_memory_usage_mb(self) -> int | None: - """Calculates the memory usage (in FP32) of the model in MB. + @staticmethod + def _calculate_num_parameters_from_hub(model_name: str | None = None) -> int | None: + try: + safetensors_metadata = get_safetensors_metadata(model_name) + if len(safetensors_metadata.parameter_count) >= 0: + return sum(safetensors_metadata.parameter_count.values()) + except ( + NotASafetensorsRepoError, + SafetensorsParsingError, + GatedRepoError, + RepositoryNotFoundError, + ) as e: + logger.warning( + f"Can't calculate number of parameters for {model_name}. Got error {e}" + ) + return None + + def calculate_num_parameters_from_hub(self) -> int | None: + """Calculates the number of parameters in the model. Returns: - The memory usage of the model in MB, or None if it cannot be determined. + Number of parameters in the model. 
""" - if "API" in self.framework: - return None + return self._calculate_num_parameters_from_hub(self.name) + @staticmethod + def _calculate_memory_usage_mb( + model_name: str, n_parameters: int | None + ) -> int | None: MB = 1024**2 # noqa: N806 try: - safetensors_metadata = get_safetensors_metadata(self.name) # type: ignore + safetensors_metadata = get_safetensors_metadata(model_name) if len(safetensors_metadata.parameter_count) >= 0: dtype_size_map = { "F64": 8, # 64-bit float @@ -322,18 +553,36 @@ def calculate_memory_usage_mb(self) -> int | None: for dtype, parameters in safetensors_metadata.parameter_count.items() ) return round(total_memory_bytes / MB) # Convert to MB + except ( + NotASafetensorsRepoError, + SafetensorsParsingError, + GatedRepoError, + RepositoryNotFoundError, + ) as e: + logger.warning( + f"Can't calculate memory usage for {model_name}. Got error {e}" + ) - except (NotASafetensorsRepoError, SafetensorsParsingError, GatedRepoError): - pass - if self.n_parameters is None: + if n_parameters is None: return None # Model memory in bytes. For FP32 each parameter is 4 bytes. - model_memory_bytes = self.n_parameters * 4 + model_memory_bytes = n_parameters * 4 # Convert to MB model_memory_mb = model_memory_bytes / MB return round(model_memory_mb) + def calculate_memory_usage_mb(self) -> int | None: + """Calculates the memory usage of the model in MB. + + Returns: + The memory usage of the model in MB, or None if it cannot be determined. + """ + if "API" in self.framework or self.name is None: + return None + + return self._calculate_memory_usage_mb(self.model_name, self.n_parameters) + @staticmethod def fetch_release_date(model_name: str) -> StrDate | None: """Fetches the release date from HuggingFace Hub based on the first commit. @@ -341,15 +590,11 @@ def fetch_release_date(model_name: str) -> StrDate | None: Returns: The release date in YYYY-MM-DD format, or None if it cannot be determined. 
""" - try: - commits = list_repo_commits(repo_id=model_name, repo_type="model") - if commits: - initial_commit = commits[-1] - release_date = initial_commit.created_at.strftime("%Y-%m-%d") - return release_date - except RepositoryNotFoundError: - logger.warning(f"Model repository not found for {model_name}.") - + commits = _get_repo_commits(repo_id=model_name, repo_type="model") + if commits: + initial_commit = commits[-1] + release_date = initial_commit.created_at.strftime("%Y-%m-%d") + return release_date return None def to_python(self) -> str: @@ -464,3 +709,35 @@ def _collect_similar_tasks(dataset: str, visited: set[str]) -> set[str]: similar.update(_collect_similar_tasks(parent, visited)) return similar + + +def _get_repo_commits(repo_id: str, repo_type: str) -> list[GitCommitInfo] | None: + try: + return list_repo_commits(repo_id=repo_id, repo_type=repo_type) + except (GatedRepoError, RepositoryNotFoundError) as e: + logger.warning(f"Can't get commits of {repo_id}: {e}") + return None + + +def _get_json_from_hub( + repo_id: str, file_name: str, repo_type: str, revision: str | None = None +) -> dict[str, Any] | None: + path = _get_file_on_hub(repo_id, file_name, repo_type, revision) + if path is None: + return None + + with Path(path).open() as f: + js = json.load(f) + return js + + +def _get_file_on_hub( + repo_id: str, file_name: str, repo_type: str, revision: str | None = None +) -> str | None: + try: + return hf_hub_download( + repo_id=repo_id, filename=file_name, repo_type=repo_type, revision=revision + ) + except (GatedRepoError, RepositoryNotFoundError, EntryNotFoundError) as e: + logger.warning(f"Can't get file {file_name} of {repo_id}: {e}") + return None diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index d9028f759e..2330d97037 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -68,11 +68,8 @@ def __init__( self.model = 
SentenceTransformer(model, revision=revision, **kwargs) else: self.model = model - from mteb.models.get_model_meta import ( - _model_meta_from_sentence_transformers, - ) - self.mteb_model_meta = _model_meta_from_sentence_transformers(self.model) + self.mteb_model_meta = ModelMeta.from_sentence_transformer_model(self.model) built_in_prompts = getattr(self.model, "prompts", None) if built_in_prompts and not model_prompts: @@ -268,14 +265,12 @@ ) -> None: from sentence_transformers import CrossEncoder - from mteb.models.get_model_meta import _model_meta_from_cross_encoder - if isinstance(model, CrossEncoder): self.model = model elif isinstance(model, str): self.model = CrossEncoder(model, revision=revision, **kwargs) - self.mteb_model_meta = _model_meta_from_cross_encoder(self.model) + self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model) def predict( self, diff --git a/tests/mock_models.py b/tests/mock_models.py index 22cff3ab33..5d1d54bb9b 100644 --- a/tests/mock_models.py +++ b/tests/mock_models.py @@ -73,6 +73,13 @@ def encode( def get_sentence_embedding_dimension() -> int: return 10 + # plain attribute, matching SentenceTransformer.max_seq_length (not a method) + max_seq_length: int = 10 + + @property + def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]: + return "cosine" + class MockSentenceTransformersbf16Encoder(MockSentenceTransformer): mteb_model_meta = ModelMeta( diff --git a/tests/test_deprecated/test_MTEB_create_model_meta.py b/tests/test_deprecated/test_MTEB_create_model_meta.py index 89c9ecfaed..132f3a010d 100644 --- a/tests/test_deprecated/test_MTEB_create_model_meta.py +++ b/tests/test_deprecated/test_MTEB_create_model_meta.py @@ -18,7 +18,7 @@ def test_create_model_meta_from_sentence_transformers(): assert meta.embed_dim == model.get_sentence_embedding_dimension() assert type(meta.framework) is list - assert meta.framework[0] == "Sentence Transformers" + assert "Sentence Transformers" in meta.framework assert meta.name == model_name assert meta.revision 
== revision diff --git a/tests/test_integrations/test_encode_args_passed.py b/tests/test_integrations/test_encode_args_passed.py index 8dd07ec1e1..18c8685771 100644 --- a/tests/test_integrations/test_encode_args_passed.py +++ b/tests/test_integrations/test_encode_args_passed.py @@ -13,6 +13,7 @@ import mteb from mteb.abstasks import AbsTask from mteb.abstasks.task_metadata import TaskMetadata +from mteb.models import ModelMeta from mteb.models.abs_encoder import AbsEncoder from mteb.types import Array, BatchedInput, PromptType from tests.task_grid import MOCK_MIEB_TASK_GRID, MOCK_TASK_TEST_GRID @@ -50,6 +51,28 @@ def test_task_metadata_passed_encoder(task: mteb.AbsTask, tmp_path: Path): _task_name = task.metadata.name class MockEncoder(AbsEncoder): + mteb_model_meta = ModelMeta( + loader=None, + name="no_model_name/available", + revision="no_revision_available", + reference=None, + release_date=None, + languages=None, + license=None, + framework=[], + training_datasets=None, + similarity_fn_name=None, + n_parameters=None, + memory_usage_mb=None, + max_tokens=None, + embed_dim=None, + open_weights=True, + public_training_code=None, + public_training_data=None, + use_instructions=None, + modalities=[], + ) + def encode( self, inputs: DataLoader[BatchedInput], diff --git a/tests/test_integrations/test_integration_with_sentencetransformers.py b/tests/test_integrations/test_integration_with_sentencetransformers.py index 3c185035f6..5932bb61ed 100644 --- a/tests/test_integrations/test_integration_with_sentencetransformers.py +++ b/tests/test_integrations/test_integration_with_sentencetransformers.py @@ -7,6 +7,7 @@ import mteb from mteb.abstasks import AbsTask +from mteb.models import ModelMeta from tests.mock_tasks import ( MockInstructionReranking, MockRerankingTask, @@ -44,3 +45,56 @@ def test_sentence_transformer_integration_cross_encoder(task: AbsTask, model_nam """Test that a task can be fetched and run""" model = CrossEncoder(model_name) mteb.evaluate(model, task, 
cache=None) + + +def test_model_meta_load_sentence_transformer_metadata_from_model(): + # used also in test CLI + model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") + meta = ModelMeta.from_sentence_transformer_model(model) + + assert meta.name == "sentence-transformers/all-MiniLM-L6-v2" + assert meta.revision is not None + assert meta.release_date == "2021-08-30" + assert meta.n_parameters == 22713728 + assert meta.memory_usage_mb == 87 + assert meta.max_tokens == 256 + assert meta.embed_dim == 384 + assert meta.license == "apache-2.0" + assert meta.similarity_fn_name.value == "cosine" + + +@pytest.mark.parametrize("as_sentence_transformer", [True, False]) +@pytest.mark.parametrize("model_name", ["sentence-transformers/all-MiniLM-L6-v2"]) +def test_model_meta_sentence_transformer_from_hub(as_sentence_transformer, model_name): + if as_sentence_transformer: + meta = ModelMeta.from_hub(model_name) + else: + meta = ModelMeta._from_hub(model_name) + + assert meta.name == "sentence-transformers/all-MiniLM-L6-v2" + assert meta.revision is not None + assert meta.release_date == "2021-08-30" + assert meta.n_parameters == 22713728 + assert meta.memory_usage_mb == 87 + assert meta.embed_dim == 384 + assert meta.license == "apache-2.0" + # model have max_position_embeddings 512, but in sentence_bert_config 256 + if as_sentence_transformer: + assert meta.similarity_fn_name.value == "cosine" + assert meta.max_tokens == 256 + else: + assert meta.max_tokens == 512 + + +@pytest.mark.parametrize("model_name", ["cross-encoder/ms-marco-TinyBERT-L2-v2"]) +def test_cross_encoder_model_meta(model_name: str): + model = CrossEncoder(model_name) + meta = ModelMeta.from_cross_encoder(model) + + assert meta.name == "cross-encoder/ms-marco-TinyBERT-L2-v2" + assert meta.revision is not None + assert meta.release_date == "2021-04-15" + assert meta.n_parameters == 4386561 + assert meta.memory_usage_mb == 17 + assert meta.max_tokens == 512 + assert meta.license == "apache-2.0"