diff --git a/mteb/models/get_model_meta.py b/mteb/models/get_model_meta.py index 23c1b860b2..d5c64b7f59 100644 --- a/mteb/models/get_model_meta.py +++ b/mteb/models/get_model_meta.py @@ -105,7 +105,10 @@ def get_model( def get_model_meta( - model_name: str, revision: str | None = None, fetch_from_hf: bool = True + model_name: str, + revision: str | None = None, + fetch_from_hf: bool = True, + fill_missing: bool = False, ) -> ModelMeta: """A function to fetch a model metadata object by name. @@ -113,6 +116,7 @@ def get_model_meta( model_name: Name of the model to fetch revision: Revision of the model to fetch fetch_from_hf: Whether to fetch the model from HuggingFace Hub if not found in the registry + fill_missing: Computes missing attributes from the metadata including number of parameters and memory usage. Returns: A model metadata object @@ -124,10 +128,25 @@ def get_model_meta( raise ValueError( f"Model revision {revision} not found for model {model_name}. Expected {model_meta.revision}." ) + + if fill_missing and fetch_from_hf: + original_meta_dict = model_meta.model_dump() + new_meta = ModelMeta.from_hub(model_name) + new_meta_dict = new_meta.model_dump(exclude_none=True) + + updates = { + k: v + for k, v in new_meta_dict.items() + if original_meta_dict.get(k) is None + } + + if updates: + return model_meta.model_copy(update=updates) return model_meta + if fetch_from_hf: logger.info( - "Model not found in model registry. Attempting to extract metadata by loading the model ({model_name}) using HuggingFace." + f"Model not found in model registry. Attempting to extract metadata by loading the model ({model_name}) using HuggingFace." ) meta = ModelMeta.from_hub(model_name, revision) return meta diff --git a/tests/test_models/test_model_meta.py b/tests/test_models/test_model_meta.py index 320ee801f1..64a98bca28 100644 --- a/tests/test_models/test_model_meta.py +++ b/tests/test_models/test_model_meta.py @@ -140,6 +140,15 @@ def test_loader_kwargs_persisted_in_metadata(): assert meta.loader_kwargs["not_existing_param"] == 123 +def test_compute_missing_parameter(): + """Test that compute_missing parameter fetches missing metadata from HuggingFace Hub""" + model_name = "sentence-transformers/all-MiniLM-L6-v2" + meta_with_compute = mteb.get_model_meta(model_name, fill_missing=True) + + assert meta_with_compute.n_parameters is not None + assert meta_with_compute.memory_usage_mb is not None + + def test_model_to_python(): meta = mteb.get_model_meta("sentence-transformers/all-MiniLM-L6-v2") assert meta.to_python() == (