diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py
index 8e8b40edfb..9a707160c4 100644
--- a/mteb/leaderboard/app.py
+++ b/mteb/leaderboard/app.py
@@ -9,13 +9,13 @@
 from typing import Literal
 from urllib.parse import urlencode
 
+import cachetools
 import gradio as gr
 import pandas as pd
 from gradio_rangeslider import RangeSlider
 
 import mteb
 from mteb.benchmarks.benchmarks import MTEB_multilingual
-from mteb.caching import json_cache
 from mteb.leaderboard.figures import performance_size_plot, radar_chart
 from mteb.leaderboard.table import scores_to_tables
 
@@ -470,7 +470,10 @@ def filter_models(
     # This sets the benchmark from the URL query parameters
     demo.load(set_benchmark_on_load, inputs=[], outputs=[benchmark_select])
 
-    @json_cache
+    @cachetools.cached(
+        cache={},
+        key=lambda benchmark_name: hash(benchmark_name),
+    )
     def on_benchmark_select(benchmark_name):
         start_time = time.time()
         benchmark = mteb.get_benchmark(benchmark_name)
@@ -495,7 +498,7 @@ def on_benchmark_select(benchmark_name):
             languages,
             domains,
             types,
-            [task.metadata.name for task in benchmark.tasks],
+            sorted([task.metadata.name for task in benchmark.tasks]),
             scores,
         )
 
@@ -505,7 +508,12 @@ def on_benchmark_select(benchmark_name):
         outputs=[lang_select, domain_select, type_select, task_select, scores],
     )
 
-    @json_cache
+    @cachetools.cached(
+        cache={},
+        key=lambda benchmark_name, languages: hash(
+            (hash(benchmark_name), hash(tuple(languages)))
+        ),
+    )
     def update_scores_on_lang_change(benchmark_name, languages):
         start_time = time.time()
         benchmark_results = all_benchmark_results[benchmark_name]
@@ -520,6 +528,17 @@ def update_scores_on_lang_change(benchmark_name, languages):
         outputs=[scores],
     )
 
+    @cachetools.cached(
+        cache={},
+        key=lambda benchmark_name, type_select, domain_select, lang_select: hash(
+            (
+                hash(benchmark_name),
+                hash(tuple(type_select)),
+                hash(tuple(domain_select)),
+                hash(tuple(lang_select)),
+            )
+        ),
+    )
     def update_task_list(benchmark_name, type_select, domain_select, lang_select):
         start_time = time.time()
         tasks_to_keep = []
@@ -533,7 +552,7 @@ def update_task_list(benchmark_name, type_select, domain_select, lang_select):
                 tasks_to_keep.append(task.metadata.name)
         elapsed = time.time() - start_time
         logger.info(f"update_task_list callback: {elapsed}s")
-        return tasks_to_keep
+        return sorted(tasks_to_keep)
 
     type_select.input(
         update_task_list,
@@ -551,6 +570,26 @@ def update_task_list(benchmark_name, type_select, domain_select, lang_select):
         outputs=[task_select],
     )
 
+    @cachetools.cached(
+        cache={},
+        key=lambda scores,
+        tasks,
+        availability,
+        compatibility,
+        instructions,
+        model_size,
+        zero_shot: hash(
+            (
+                id(scores),
+                hash(tuple(tasks)),
+                hash(availability),
+                hash(tuple(compatibility)),
+                hash(instructions),
+                hash(model_size),
+                hash(zero_shot),
+            )
+        ),
+    )
     def update_models(
         scores: list[dict],
         tasks: list[str],
@@ -572,8 +611,11 @@ def update_models(
             zero_shot_setting=zero_shot,
         )
         elapsed = time.time() - start_time
+        if model_names == filtered_models:
+            # This indicates that the models should not be filtered
+            return None
         logger.info(f"update_models callback: {elapsed}s")
-        return filtered_models
+        return sorted(filtered_models)
 
     scores.change(
         update_models,
@@ -667,22 +709,41 @@ def update_models(
         outputs=[models],
     )
 
+    @cachetools.cached(
+        cache={},
+        key=lambda scores, search_query, tasks, models_to_keep, benchmark_name: hash(
+            (
+                id(scores),
+                hash(search_query),
+                hash(tuple(tasks)),
+                id(models_to_keep),
+                hash(benchmark_name),
+            )
+        ),
+    )
     def update_tables(
         scores,
         search_query: str,
         tasks,
         models_to_keep,
+        benchmark_name: str,
     ):
         start_time = time.time()
         tasks = set(tasks)
-        models_to_keep = set(models_to_keep)
-        filtered_scores = []
-        for entry in scores:
-            if entry["task_name"] not in tasks:
-                continue
-            if entry["model_name"] not in models_to_keep:
-                continue
-            filtered_scores.append(entry)
+        benchmark = mteb.get_benchmark(benchmark_name)
+        benchmark_tasks = {task.metadata.name for task in benchmark.tasks}
+        if (benchmark_tasks != tasks) or (models_to_keep is not None):
+            filtered_scores = []
+            for entry in scores:
+                if entry["task_name"] not in tasks:
+                    continue
+                if (models_to_keep is not None) and (
+                    entry["model_name"] not in models_to_keep
+                ):
+                    continue
+                filtered_scores.append(entry)
+        else:
+            filtered_scores = scores
         summary, per_task = scores_to_tables(filtered_scores, search_query)
         elapsed = time.time() - start_time
         logger.info(f"update_tables callback: {elapsed}s")
@@ -690,26 +751,50 @@ def update_tables(
 
     task_select.change(
         update_tables,
-        inputs=[scores, searchbar, task_select, models],
+        inputs=[scores, searchbar, task_select, models, benchmark_select],
        outputs=[summary_table, per_task_table],
     )
 
     scores.change(
         update_tables,
-        inputs=[scores, searchbar, task_select, models],
+        inputs=[scores, searchbar, task_select, models, benchmark_select],
         outputs=[summary_table, per_task_table],
     )
 
     models.change(
         update_tables,
-        inputs=[scores, searchbar, task_select, models],
+        inputs=[scores, searchbar, task_select, models, benchmark_select],
         outputs=[summary_table, per_task_table],
     )
 
     searchbar.submit(
         update_tables,
-        inputs=[scores, searchbar, task_select, models],
+        inputs=[scores, searchbar, task_select, models, benchmark_select],
         outputs=[summary_table, per_task_table],
     )
 
     gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
+
+# Prerun on all benchmarks, so that results of callbacks get cached
+for benchmark in benchmarks:
+    bench_languages, bench_domains, bench_types, bench_tasks, bench_scores = (
+        on_benchmark_select(benchmark.name)
+    )
+    filtered_models = update_models(
+        bench_scores,
+        bench_tasks,
+        availability=None,
+        compatibility=[],
+        instructions=None,
+        model_size=(MIN_MODEL_SIZE, MAX_MODEL_SIZE),
+        zero_shot="soft",
+    )
+    # We have to call this on both the filtered and unfiltered tasks, because the
+    # callbacks also get called twice for some reason
+    update_tables(bench_scores, "", bench_tasks, filtered_models, benchmark.name)
+    filtered_tasks = update_task_list(
+        benchmark.name, bench_types, bench_domains, bench_languages
+    )
+    update_tables(bench_scores, "", filtered_tasks, filtered_models, benchmark.name)
+
+
 if __name__ == "__main__":
     demo.launch()
diff --git a/pyproject.toml b/pyproject.toml
index bbfea6e3e7..563480406d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,7 @@ dev = [
 codecarbon = ["codecarbon"]
 speedtask = ["GPUtil>=1.4.0", "psutil>=5.9.8"]
 peft = ["peft>=0.11.0"]
-leaderboard = ["gradio>=5.16.0", "gradio_rangeslider>=0.0.8", "plotly>=5.24.0,<6.0.0"]
+leaderboard = ["gradio>=5.16.0,<6.0.0", "gradio_rangeslider>=0.0.8", "plotly>=5.24.0,<6.0.0", "cachetools>=5.2.0"]
 flagembedding = ["FlagEmbedding"]
 jina = ["einops>=0.8.0"]
 flash_attention = ["flash-attn>=2.6.3"]
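Note on the caching pattern: every decorator above uses `cachetools.cached` with a plain `dict` as the backing store and a custom `key` function, so entries are never evicted and live for the lifetime of the process; the prerun loop at the bottom of `app.py` warms these caches at import time so the first user request is already a cache hit. List arguments are unhashable and are therefore collapsed with `tuple(...)` before hashing, while the large `scores` payload is keyed by `id(...)`, which matches only when the very same object is passed again (which the prerun loop guarantees). A minimal, self-contained sketch of the same pattern, assuming `cachetools>=5` (`lookup_scores` and its arguments are hypothetical names for illustration, not part of this PR):

```python
import cachetools


@cachetools.cached(
    cache={},  # plain dict: no eviction, entries live for the process lifetime
    key=lambda benchmark_name, languages: hash(
        # lists are unhashable, so collapse them to a tuple before hashing
        (hash(benchmark_name), hash(tuple(languages)))
    ),
)
def lookup_scores(benchmark_name: str, languages: list[str]) -> str:
    print("cache miss, computing...")  # printed only on the first call per key
    return f"{benchmark_name}: {', '.join(sorted(languages))}"


print(lookup_scores("MTEB(Multilingual)", ["eng", "fra"]))  # computes
print(lookup_scores("MTEB(Multilingual)", ["eng", "fra"]))  # served from the cache
```

One trade-off worth noting: keying on `id(scores)` is O(1) regardless of payload size, but it only hits on object identity and could in principle collide if an object is garbage-collected and its address reused; here the prerun results are kept alive by the caches themselves, so that should not occur in practice.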