diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py index 0cf8cae482..c2cbb23a8d 100644 --- a/mteb/leaderboard/app.py +++ b/mteb/leaderboard/app.py @@ -5,7 +5,7 @@ import time import warnings from pathlib import Path -from typing import Literal +from typing import Literal, get_args from urllib.parse import urlencode import cachetools @@ -29,10 +29,12 @@ apply_summary_styling_from_benchmark, ) from mteb.leaderboard.text_segments import ACKNOWLEDGEMENT, FAQ +from mteb.models.model_meta import MODEL_TYPES logger = logging.getLogger(__name__) LANGUAGE: list[str] = list({l for t in mteb.get_tasks() for l in t.metadata.languages}) +MODEL_TYPE_CHOICES = list(get_args(MODEL_TYPES)) def _load_results(cache: ResultCache) -> BenchmarkResults: @@ -187,6 +189,7 @@ def _filter_models( instructions: bool | None, max_model_size: int, zero_shot_setting: Literal["only_zero_shot", "allow_all", "remove_unknown"], + model_types: list[str] | None, ): lower, upper = 0, max_model_size # Setting to None, when the user doesn't specify anything @@ -205,6 +208,7 @@ def _filter_models( use_instructions=instructions, frameworks=compatibility, n_parameters_range=(lower, upper), + model_types=model_types, ) models_to_keep = set() @@ -269,6 +273,7 @@ def _cache_on_benchmark_select(benchmark_name, all_benchmark_results): instructions=None, max_model_size=MAX_MODEL_SIZE, zero_shot_setting="allow_all", + model_types=MODEL_TYPE_CHOICES, ) # Sort to ensure consistency with update_models initial_models = sorted(initial_models) @@ -387,6 +392,7 @@ def get_leaderboard_app(cache: ResultCache = ResultCache()) -> gr.Blocks: instructions=None, max_model_size=MAX_MODEL_SIZE, zero_shot_setting="allow_all", + model_types=MODEL_TYPE_CHOICES, ) default_filtered_scores = [ entry for entry in default_scores if entry["model_name"] in filtered_models @@ -583,6 +589,12 @@ def get_leaderboard_app(cache: ResultCache = ResultCache()) -> gr.Blocks: label="Model Parameters", interactive=True, ) + with gr.Column(): + model_type_select = gr.CheckboxGroup( + MODEL_TYPE_CHOICES, + value=MODEL_TYPE_CHOICES, + label="Model Type", + ) with gr.Tab("Summary"): summary_table.render() @@ -755,7 +767,8 @@ def update_task_list( compatibility, instructions, max_model_size, - zero_shot: hash( + zero_shot, + model_type_select: hash( ( id(scores), hash(tuple(tasks)), @@ -764,6 +777,7 @@ def update_task_list( hash(instructions), hash(max_model_size), hash(zero_shot), + hash(tuple(model_type_select)), ) ), ) @@ -775,6 +789,7 @@ def update_models( instructions: bool | None, max_model_size: int, zero_shot: Literal["allow_all", "remove_unknown", "only_zero_shot"], + model_type_select: list[str], ): start_time = time.time() model_names = list({entry["model_name"] for entry in scores}) @@ -786,6 +801,7 @@ def update_models( instructions, max_model_size, zero_shot_setting=zero_shot, + model_types=model_type_select, ) elapsed = time.time() - start_time logger.debug(f"update_models callback: {elapsed}s") @@ -803,6 +819,7 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, ], outputs=[models], ) @@ -817,6 +834,7 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, ], outputs=[models], ) @@ -830,6 +848,7 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, ], outputs=[models], ) @@ -843,6 +862,7 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, ], outputs=[models], ) @@ -856,6 +876,7 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, ], outputs=[models], ) @@ -869,6 +890,7 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, ], outputs=[models], ) @@ -882,6 +904,21 @@ def update_models( instructions, max_model_size, zero_shot, + model_type_select, + ], + outputs=[models], + ) + model_type_select.change( + update_models, + inputs=[ + scores, + task_select, + availability, + compatibility, + instructions, + max_model_size, + zero_shot, + model_type_select, ], outputs=[models], ) diff --git a/mteb/models/get_model_meta.py b/mteb/models/get_model_meta.py index 4c83d29372..24ddaed6e4 100644 --- a/mteb/models/get_model_meta.py +++ b/mteb/models/get_model_meta.py @@ -21,6 +21,7 @@ def get_model_metas( n_parameters_range: tuple[int | None, int | None] = (None, None), use_instructions: bool | None = None, zero_shot_on: list[AbsTask] | None = None, + model_types: Iterable[str] | None = None, ) -> list[ModelMeta]: """Load all models' metadata that fit the specified criteria. @@ -33,6 +34,7 @@ def get_model_metas( If (None, None), this filter is ignored. use_instructions: Whether to filter by models that use instructions. If None, all models are included. zero_shot_on: A list of tasks on which the model is zero-shot. If None this filter is ignored. + model_types: A list of model types to filter by. If None, all model types are included. Returns: A list of model metadata objects that fit the specified criteria. @@ -41,6 +43,7 @@ def get_model_metas( model_names = set(model_names) if model_names is not None else None languages = set(languages) if languages is not None else None frameworks = set(frameworks) if frameworks is not None else None + model_types_set = set(model_types) if model_types is not None else None for model_meta in MODEL_REGISTRY.values(): if (model_names is not None) and (model_meta.name not in model_names): continue @@ -57,6 +60,10 @@ def get_model_metas( model_meta.use_instructions != use_instructions ): continue + if model_types_set is not None and not model_types_set.intersection( + model_meta.model_type + ): + continue lower, upper = n_parameters_range n_parameters = model_meta.n_parameters diff --git a/tests/test_models/test_model_meta.py b/tests/test_models/test_model_meta.py index 320ee801f1..bfb676fdf8 100644 --- a/tests/test_models/test_model_meta.py +++ b/tests/test_models/test_model_meta.py @@ -127,6 +127,15 @@ def test_check_training_datasets_can_be_derived(model_meta: ModelMeta): model_meta.get_training_datasets() +@pytest.mark.parametrize("model_type", ["dense", "cross-encoder", "late-interaction"]) +def test_get_model_metas_each_model_type(model_type): + """Test filtering by each individual model type.""" + models = mteb.get_model_metas(model_types=[model_type]) + + for model in models: + assert model_type in model.model_type + + def test_loader_kwargs_persisted_in_metadata(): model = mteb.get_model( "baseline/random-encoder-baseline",