mteb/leaderboard/app.py: 121 changes (103 additions, 18 deletions)
@@ -9,13 +9,13 @@
from typing import Literal
from urllib.parse import urlencode

import cachetools
import gradio as gr
import pandas as pd
from gradio_rangeslider import RangeSlider

import mteb
from mteb.benchmarks.benchmarks import MTEB_multilingual
from mteb.caching import json_cache
from mteb.leaderboard.figures import performance_size_plot, radar_chart
from mteb.leaderboard.table import scores_to_tables

@@ -470,7 +470,10 @@ def filter_models(
# This sets the benchmark from the URL query parameters
demo.load(set_benchmark_on_load, inputs=[], outputs=[benchmark_select])

@json_cache
@cachetools.cached(
cache={},
key=lambda benchmark_name: hash(benchmark_name),
)
def on_benchmark_select(benchmark_name):
start_time = time.time()
benchmark = mteb.get_benchmark(benchmark_name)
@@ -495,7 +498,7 @@ def on_benchmark_select(benchmark_name):
languages,
domains,
types,
[task.metadata.name for task in benchmark.tasks],
sorted([task.metadata.name for task in benchmark.tasks]),
scores,
)

@@ -505,7 +508,12 @@ def on_benchmark_select(benchmark_name):
outputs=[lang_select, domain_select, type_select, task_select, scores],
)

@json_cache
@cachetools.cached(
cache={},
key=lambda benchmark_name, languages: hash(
(hash(benchmark_name), hash(tuple(languages)))
),
)
def update_scores_on_lang_change(benchmark_name, languages):
start_time = time.time()
benchmark_results = all_benchmark_results[benchmark_name]
@@ -520,6 +528,17 @@ def update_scores_on_lang_change(benchmark_name, languages):
outputs=[scores],
)

@cachetools.cached(
cache={},
key=lambda benchmark_name, type_select, domain_select, lang_select: hash(
(
hash(benchmark_name),
hash(tuple(type_select)),
hash(tuple(domain_select)),
hash(tuple(lang_select)),
)
),
)
def update_task_list(benchmark_name, type_select, domain_select, lang_select):
start_time = time.time()
tasks_to_keep = []
@@ -533,7 +552,7 @@ def update_task_list(benchmark_name, type_select, domain_select, lang_select):
tasks_to_keep.append(task.metadata.name)
elapsed = time.time() - start_time
logger.info(f"update_task_list callback: {elapsed}s")
return tasks_to_keep
return sorted(tasks_to_keep)

type_select.input(
update_task_list,
@@ -551,6 +570,26 @@ def update_task_list(benchmark_name, type_select, domain_select, lang_select):
outputs=[task_select],
)

@cachetools.cached(
cache={},
key=lambda scores,
tasks,
availability,
compatibility,
instructions,
model_size,
zero_shot: hash(
(
id(scores),
hash(tuple(tasks)),
hash(availability),
hash(tuple(compatibility)),
hash(instructions),
hash(model_size),
hash(zero_shot),
)
),
)
def update_models(
scores: list[dict],
tasks: list[str],
@@ -572,8 +611,11 @@ def update_models(
zero_shot_setting=zero_shot,
)
elapsed = time.time() - start_time
if model_names == filtered_models:
# This indicates that the models should not be filtered
return None
logger.info(f"update_models callback: {elapsed}s")
return filtered_models
return sorted(filtered_models)

scores.change(
update_models,
@@ -667,49 +709,92 @@ def update_models(
outputs=[models],
)

@cachetools.cached(
cache={},
key=lambda scores, search_query, tasks, models_to_keep, benchmark_name: hash(
(
id(scores),
hash(search_query),
hash(tuple(tasks)),
id(models_to_keep),
hash(benchmark_name),
)
),
)
def update_tables(
scores,
search_query: str,
tasks,
models_to_keep,
benchmark_name: str,
):
start_time = time.time()
tasks = set(tasks)
models_to_keep = set(models_to_keep)
filtered_scores = []
for entry in scores:
if entry["task_name"] not in tasks:
continue
if entry["model_name"] not in models_to_keep:
continue
filtered_scores.append(entry)
benchmark = mteb.get_benchmark(benchmark_name)
benchmark_tasks = {task.metadata.name for task in benchmark.tasks}
if (benchmark_tasks != tasks) or (models_to_keep is not None):
filtered_scores = []
for entry in scores:
if entry["task_name"] not in tasks:
continue
if (models_to_keep is not None) and (
entry["model_name"] not in models_to_keep
):
continue
filtered_scores.append(entry)
else:
filtered_scores = scores
summary, per_task = scores_to_tables(filtered_scores, search_query)
elapsed = time.time() - start_time
logger.info(f"update_tables callback: {elapsed}s")
return summary, per_task

task_select.change(
update_tables,
inputs=[scores, searchbar, task_select, models],
inputs=[scores, searchbar, task_select, models, benchmark_select],
outputs=[summary_table, per_task_table],
)
scores.change(
update_tables,
inputs=[scores, searchbar, task_select, models],
inputs=[scores, searchbar, task_select, models, benchmark_select],
outputs=[summary_table, per_task_table],
)
models.change(
update_tables,
inputs=[scores, searchbar, task_select, models],
inputs=[scores, searchbar, task_select, models, benchmark_select],
outputs=[summary_table, per_task_table],
)
searchbar.submit(
update_tables,
inputs=[scores, searchbar, task_select, models],
inputs=[scores, searchbar, task_select, models, benchmark_select],
outputs=[summary_table, per_task_table],
)

gr.Markdown(acknowledgment_md, elem_id="ack_markdown")


# Prerun on all benchmarks, so that results of callbacks get cached
for benchmark in benchmarks:
bench_languages, bench_domains, bench_types, bench_tasks, bench_scores = (
on_benchmark_select(benchmark.name)
)
filtered_models = update_models(
bench_scores,
bench_tasks,
availability=None,
compatibility=[],
instructions=None,
model_size=(MIN_MODEL_SIZE, MAX_MODEL_SIZE),
zero_shot="soft",
)
# We have to call this on both the filtered and unfiltered task list, because the callbacks
# also get called twice for some reason
update_tables(bench_scores, "", bench_tasks, filtered_models, benchmark.name)
filtered_tasks = update_task_list(
benchmark.name, bench_types, bench_domains, bench_languages
)
update_tables(bench_scores, "", filtered_tasks, filtered_models, benchmark.name)


if __name__ == "__main__":
demo.launch()
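
The recurring pattern in the app.py changes above is cachetools.cached with an explicit key= function, so that callbacks whose arguments are unhashable (lists of languages or tasks, score dicts) can still be memoized in a plain dict. Below is a minimal, self-contained sketch of that pattern, not code from the PR; the function name, arguments, and benchmark string are made up for illustration. One design note: returning a tuple of hashable parts as the key (rather than wrapping everything in hash()) avoids the small risk of two different inputs colliding on the same hash value and sharing a cache entry, and a key built from id(obj), as used for the scores argument in the PR, only produces cache hits while the very same object is passed again.

import time

import cachetools


@cachetools.cached(
    cache={},  # a plain dict acts as an unbounded, per-process cache
    key=lambda benchmark_name, languages: (benchmark_name, tuple(languages)),
)
def slow_lookup(benchmark_name: str, languages: list[str]) -> str:
    # Stand-in for an expensive leaderboard callback, e.g. loading benchmark scores.
    time.sleep(1)
    return f"{benchmark_name}: {', '.join(sorted(languages))}"


if __name__ == "__main__":
    start = time.time()
    slow_lookup("MTEB(Multilingual)", ["eng", "fra"])  # cache miss, takes ~1 s
    slow_lookup("MTEB(Multilingual)", ["eng", "fra"])  # cache hit, returns immediately
    print(f"two calls took {time.time() - start:.2f} s")
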
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ dev = [
codecarbon = ["codecarbon"]
speedtask = ["GPUtil>=1.4.0", "psutil>=5.9.8"]
peft = ["peft>=0.11.0"]
leaderboard = ["gradio>=5.16.0", "gradio_rangeslider>=0.0.8", "plotly>=5.24.0,<6.0.0"]
leaderboard = ["gradio>=5.16.0,<6.0.0", "gradio_rangeslider>=0.0.8", "plotly>=5.24.0,<6.0.0", "cachetools>=5.2.0"]
flagembedding = ["FlagEmbedding"]
jina = ["einops>=0.8.0"]
flash_attention = ["flash-attn>=2.6.3"]
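
The pyproject.toml change above pins gradio below 6.0.0 and adds cachetools>=5.2.0 to the leaderboard extra. A small, hypothetical sanity check, assuming the extra was installed with something like pip install "mteb[leaderboard]", to confirm that the dependencies resolve in the current environment:

from importlib.metadata import PackageNotFoundError, version

# Print the installed version of each leaderboard dependency, if present.
for pkg in ("gradio", "gradio_rangeslider", "plotly", "cachetools"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")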