Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions mteb/leaderboard/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import warnings
from pathlib import Path
from typing import Literal
from typing import Literal, get_args
from urllib.parse import urlencode

import cachetools
Expand All @@ -29,10 +29,12 @@
apply_summary_styling_from_benchmark,
)
from mteb.leaderboard.text_segments import ACKNOWLEDGEMENT, FAQ
from mteb.models.model_meta import MODEL_TYPES

logger = logging.getLogger(__name__)

LANGUAGE: list[str] = list({l for t in mteb.get_tasks() for l in t.metadata.languages})
MODEL_TYPE_CHOICES = list(get_args(MODEL_TYPES))


def _load_results(cache: ResultCache) -> BenchmarkResults:
Expand Down Expand Up @@ -187,6 +189,7 @@ def _filter_models(
instructions: bool | None,
max_model_size: int,
zero_shot_setting: Literal["only_zero_shot", "allow_all", "remove_unknown"],
model_types: list[str] | None,
):
lower, upper = 0, max_model_size
# Setting to None, when the user doesn't specify anything
Expand All @@ -205,6 +208,7 @@ def _filter_models(
use_instructions=instructions,
frameworks=compatibility,
n_parameters_range=(lower, upper),
model_types=model_types,
)

models_to_keep = set()
Expand Down Expand Up @@ -269,6 +273,7 @@ def _cache_on_benchmark_select(benchmark_name, all_benchmark_results):
instructions=None,
max_model_size=MAX_MODEL_SIZE,
zero_shot_setting="allow_all",
model_types=MODEL_TYPE_CHOICES,
)
# Sort to ensure consistency with update_models
initial_models = sorted(initial_models)
Expand Down Expand Up @@ -387,6 +392,7 @@ def get_leaderboard_app(cache: ResultCache = ResultCache()) -> gr.Blocks:
instructions=None,
max_model_size=MAX_MODEL_SIZE,
zero_shot_setting="allow_all",
model_types=MODEL_TYPE_CHOICES,
)
default_filtered_scores = [
entry for entry in default_scores if entry["model_name"] in filtered_models
Expand Down Expand Up @@ -583,6 +589,12 @@ def get_leaderboard_app(cache: ResultCache = ResultCache()) -> gr.Blocks:
label="Model Parameters",
interactive=True,
)
with gr.Column():
model_type_select = gr.CheckboxGroup(
MODEL_TYPE_CHOICES,
value=MODEL_TYPE_CHOICES,
label="Model Type",
)

with gr.Tab("Summary"):
summary_table.render()
Expand Down Expand Up @@ -755,7 +767,8 @@ def update_task_list(
compatibility,
instructions,
max_model_size,
zero_shot: hash(
zero_shot,
model_type_select: hash(
(
id(scores),
hash(tuple(tasks)),
Expand All @@ -764,6 +777,7 @@ def update_task_list(
hash(instructions),
hash(max_model_size),
hash(zero_shot),
hash(tuple(model_type_select)),
)
),
)
Expand All @@ -775,6 +789,7 @@ def update_models(
instructions: bool | None,
max_model_size: int,
zero_shot: Literal["allow_all", "remove_unknown", "only_zero_shot"],
model_type_select: list[str],
):
start_time = time.time()
model_names = list({entry["model_name"] for entry in scores})
Expand All @@ -786,6 +801,7 @@ def update_models(
instructions,
max_model_size,
zero_shot_setting=zero_shot,
model_types=model_type_select,
)
elapsed = time.time() - start_time
logger.debug(f"update_models callback: {elapsed}s")
Expand All @@ -803,6 +819,7 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand All @@ -817,6 +834,7 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand All @@ -830,6 +848,7 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand All @@ -843,6 +862,7 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand All @@ -856,6 +876,7 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand All @@ -869,6 +890,7 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand All @@ -882,6 +904,21 @@ def update_models(
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
model_type_select.change(
update_models,
inputs=[
scores,
task_select,
availability,
compatibility,
instructions,
max_model_size,
zero_shot,
model_type_select,
],
outputs=[models],
)
Expand Down
7 changes: 7 additions & 0 deletions mteb/models/get_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def get_model_metas(
n_parameters_range: tuple[int | None, int | None] = (None, None),
use_instructions: bool | None = None,
zero_shot_on: list[AbsTask] | None = None,
model_types: Iterable[str] | None = None,
) -> list[ModelMeta]:
"""Load all models' metadata that fit the specified criteria.

Expand All @@ -33,6 +34,7 @@ def get_model_metas(
If (None, None), this filter is ignored.
use_instructions: Whether to filter by models that use instructions. If None, all models are included.
zero_shot_on: A list of tasks on which the model is zero-shot. If None this filter is ignored.
model_types: A list of model types to filter by. If None, all model types are included.

Returns:
A list of model metadata objects that fit the specified criteria.
Expand All @@ -41,6 +43,7 @@ def get_model_metas(
model_names = set(model_names) if model_names is not None else None
languages = set(languages) if languages is not None else None
frameworks = set(frameworks) if frameworks is not None else None
model_types_set = set(model_types) if model_types is not None else None
for model_meta in MODEL_REGISTRY.values():
if (model_names is not None) and (model_meta.name not in model_names):
continue
Expand All @@ -57,6 +60,10 @@ def get_model_metas(
model_meta.use_instructions != use_instructions
):
continue
if model_types_set is not None and not model_types_set.intersection(
model_meta.model_type
):
continue

lower, upper = n_parameters_range
n_parameters = model_meta.n_parameters
Expand Down
9 changes: 9 additions & 0 deletions tests/test_models/test_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def test_check_training_datasets_can_be_derived(model_meta: ModelMeta):
model_meta.get_training_datasets()


@pytest.mark.parametrize("model_type", ["dense", "cross-encoder", "late-interaction"])
def test_get_model_metas_each_model_type(model_type):
    """Filtering by a single model type must return only models of that type."""
    filtered = mteb.get_model_metas(model_types=[model_type])

    assert all(model_type in meta.model_type for meta in filtered)


def test_loader_kwargs_persisted_in_metadata():
model = mteb.get_model(
"baseline/random-encoder-baseline",
Expand Down