diff --git a/mteb/cli.py b/mteb/cli.py
index b891d381f4..65d6938416 100644
--- a/mteb/cli.py
+++ b/mteb/cli.py
@@ -122,12 +122,16 @@ def run(args: argparse.Namespace) -> None:
 
     model = mteb.get_model(args.model, args.model_revision, device=device)
 
-    tasks = mteb.get_tasks(
-        categories=args.categories,
-        task_types=args.task_types,
-        languages=args.languages,
-        tasks=args.tasks,
-    )
+    if args.benchmarks:
+        tasks = mteb.get_benchmarks(names=args.benchmarks)
+    else:
+        tasks = mteb.get_tasks(
+            categories=args.categories,
+            task_types=args.task_types,
+            languages=args.languages,
+            tasks=args.tasks,
+        )
+
     eval = mteb.MTEB(tasks=tasks)
 
     encode_kwargs = {}
@@ -153,7 +157,7 @@ def run(args: argparse.Namespace) -> None:
 
 
 def available_benchmarks(args: argparse.Namespace) -> None:
-    benchmarks = mteb.get_benchmarks()
+    benchmarks = mteb.get_benchmarks(names=args.benchmarks)
     eval = mteb.MTEB(tasks=benchmarks)
     eval.mteb_benchmarks()
 
@@ -169,6 +173,18 @@ def available_tasks(args: argparse.Namespace) -> None:
     eval.mteb_tasks()
 
 
+def add_benchmark_selection_args(parser: argparse.ArgumentParser) -> None:
+    """Adds arguments to the parser for filtering benchmarks by name."""
+    parser.add_argument(
+        "-b",
+        "--benchmarks",
+        nargs="+",
+        type=str,
+        default=None,
+        help="List of benchmarks to be evaluated.",
+    )
+
+
 def add_task_selection_args(parser: argparse.ArgumentParser) -> None:
     """Adds arguments to the parser for filtering tasks by type, category, language, and task name."""
     parser.add_argument(
@@ -216,7 +232,7 @@ def add_available_benchmarks_parser(subparsers) -> None:
     parser = subparsers.add_parser(
         "available_benchmarks", help="List the available benchmarks within MTEB"
     )
-    add_task_selection_args(parser)
+    add_benchmark_selection_args(parser)
     parser.set_defaults(func=available_benchmarks)
 
 
@@ -232,6 +248,7 @@ def add_run_parser(subparsers) -> None:
     )
 
     add_task_selection_args(parser)
+    add_benchmark_selection_args(parser)
 
     parser.add_argument(
         "--device", type=int, default=None, help="Device to use for computation"
diff --git a/mteb/evaluation/MTEB.py b/mteb/evaluation/MTEB.py
index 70f3e21ca8..d28b3a1bf6 100644
--- a/mteb/evaluation/MTEB.py
+++ b/mteb/evaluation/MTEB.py
@@ -6,6 +6,7 @@
 import traceback
 from copy import copy
 from datetime import datetime
+from itertools import chain
 from pathlib import Path
 from time import time
 from typing import Any, Iterable
@@ -52,12 +53,17 @@ def __init__(
             err_logs_path: Path to save error logs.
             kwargs: Additional arguments to be passed to the tasks
         """
+        from mteb.benchmarks import Benchmark
+
         self.deprecation_warning(
             task_types, task_categories, task_langs, tasks, version
         )
         if tasks is not None:
             self._tasks = tasks
+            if isinstance(tasks[0], Benchmark):
+                self.benchmarks = tasks
+                self._tasks = list(chain.from_iterable(tasks))
             assert (
                 task_types is None and task_categories is None
             ), "Cannot specify both `tasks` and `task_types`/`task_categories`"
@@ -170,7 +176,7 @@ def _display_tasks(self, task_list, name=None):
 
     def mteb_benchmarks(self):
         """Get all benchmarks available in the MTEB."""
-        for benchmark in self._tasks:
+        for benchmark in self.benchmarks:
             name = benchmark.name
             self._display_tasks(benchmark.tasks, name=name)
diff --git a/tests/test_benchmark/test_benchmark.py b/tests/test_benchmark/test_benchmark.py
index 612705fe72..e29d4be5ce 100644
--- a/tests/test_benchmark/test_benchmark.py
+++ b/tests/test_benchmark/test_benchmark.py
@@ -159,6 +159,19 @@ def test_run_using_benchmark(model: mteb.Encoder):
     )  # we just want to test that it runs
 
 
+@pytest.mark.parametrize("model", [MockNumpyEncoder()])
+def test_run_using_list_of_benchmarks(model: mteb.Encoder):
+    """Test that a list of benchmark objects can be run using the MTEB class."""
+    bench = [
+        Benchmark(name="test_bench", tasks=mteb.get_tasks(tasks=["STS12", "SummEval"]))
+    ]
+
+    eval = mteb.MTEB(tasks=bench)
+    eval.run(
+        model, output_folder="tests/results", overwrite_results=True
+    )  # we just want to test that it runs
+
+
 def test_benchmark_names_must_be_unique():
     import mteb.benchmarks.benchmarks as benchmark_module
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 1d0400e985..00f7483f50 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -28,7 +28,7 @@ def test_available_benchmarks():
     assert result.returncode == 0, "Command failed"
     assert (
         "MTEB(eng)" in result.stdout
-    ), "Sample benchmark MTEB(eng) task not found in available bencmarks"
+    ), "Sample benchmark MTEB(eng) task not found in available benchmarks"
 
 
 run_task_fixures = [
@@ -65,6 +65,7 @@ def test_run_task(
         co2_tracker=None,
         overwrite=True,
         eval_splits=None,
+        benchmarks=None,
     )
 
     run(args)
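
Usage sketch (not part of the patch): a minimal, hypothetical example of the Python-side flow this change enables, passing `Benchmark` objects straight to `mteb.MTEB` via `mteb.get_benchmarks(names=...)`. The model name and output folder are illustrative assumptions; the benchmark name "MTEB(eng)" is the one referenced in the CLI test above.

```python
# Minimal sketch, not part of the diff; model name and output folder are illustrative.
import mteb

# Fetch benchmarks by name (the new CLI path calls mteb.get_benchmarks(names=args.benchmarks)).
benchmarks = mteb.get_benchmarks(names=["MTEB(eng)"])

# Hypothetical model choice for the example.
model = mteb.get_model("sentence-transformers/all-MiniLM-L6-v2")

# Benchmark objects can now be passed directly as `tasks`; MTEB.__init__ flattens them
# into individual tasks with itertools.chain, as shown in the MTEB.py hunk above.
evaluation = mteb.MTEB(tasks=benchmarks)
evaluation.run(model, output_folder="results")
```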