Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions mteb/models/get_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,18 @@ def get_model(


def get_model_meta(
model_name: str, revision: str | None = None, fetch_from_hf: bool = True
model_name: str,
revision: str | None = None,
fetch_from_hf: bool = True,
fill_missing: bool = False,
) -> ModelMeta:
"""A function to fetch a model metadata object by name.

Args:
model_name: Name of the model to fetch
revision: Revision of the model to fetch
fetch_from_hf: Whether to fetch the model from HuggingFace Hub if not found in the registry
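fill_missing: Whether to fill in attributes that are missing (None) in the registry metadata, such as number of parameters and memory usage, with values fetched from HuggingFace Hub. Only takes effect when fetch_from_hf is True.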

Returns:
A model metadata object
Expand All @@ -124,10 +128,25 @@ def get_model_meta(
raise ValueError(
f"Model revision {revision} not found for model {model_name}. Expected {model_meta.revision}."
)

if fill_missing and fetch_from_hf:
original_meta_dict = model_meta.model_dump()
new_meta = ModelMeta.from_hub(model_name)
new_meta_dict = new_meta.model_dump(exclude_none=True)

updates = {
k: v
for k, v in new_meta_dict.items()
if original_meta_dict.get(k) is None
}

if updates:
return model_meta.model_copy(update=updates)
return model_meta

if fetch_from_hf:
logger.info(
"Model not found in model registry. Attempting to extract metadata by loading the model ({model_name}) using HuggingFace."
f"Model not found in model registry. Attempting to extract metadata by loading the model ({model_name}) using HuggingFace."
)
meta = ModelMeta.from_hub(model_name, revision)
return meta
Expand Down
9 changes: 9 additions & 0 deletions tests/test_models/test_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def test_loader_kwargs_persisted_in_metadata():
assert meta.loader_kwargs["not_existing_param"] == 123


def test_compute_missing_parameter():
"""Test that the fill_missing parameter fetches missing metadata from HuggingFace Hub."""
model_name = "sentence-transformers/all-MiniLM-L6-v2"
meta_with_compute = mteb.get_model_meta(model_name, fill_missing=True)

assert meta_with_compute.n_parameters is not None
assert meta_with_compute.memory_usage_mb is not None


def test_model_to_python():
meta = mteb.get_model_meta("sentence-transformers/all-MiniLM-L6-v2")
assert meta.to_python() == (
Expand Down