Skip to content

Commit

Permalink
Added integration test for four model types
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Jan 23, 2024
1 parent f6f71db commit 61ff3bb
Showing 1 changed file with 43 additions and 8 deletions.
51 changes: 43 additions & 8 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pytest

import seb
from seb.cli import cli, run_benchmark_cli

Expand All @@ -30,13 +31,29 @@ def to_command(self, output_path: Path) -> list[str]:
# Parametrization over BenchmarkCliTestInput cases, bound to the argument name
# "inputs". Each case names a model and a numeric value; the tests in this file
# compare `inputs.score` against the benchmark's main score, so the number is
# presumably that expected score (TODO confirm against BenchmarkCliTestInput).
# Cases optionally restrict the run via `tasks` or `languages`; the
# "test_model" case supplies `code_path`, uses np.nan as its value and sets
# ignore_cache=True.
cli_command_parametrize = pytest.mark.parametrize(
"inputs",
[
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.448, None, None),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.550, tasks=["DKHate"]),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.525, tasks=["DKHate", "ScaLA"]),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.487, languages=["sv", "no", "nn"]),
BenchmarkCliTestInput("sentence-transformers/all-MiniLM-L6-v2", 0.423, languages=["da"]),
BenchmarkCliTestInput(
"test_model", np.nan, code_path=(test_dir / "benchmark_cli_code_inject.py"), tasks=["test-encode-task"], ignore_cache=True
"sentence-transformers/all-MiniLM-L6-v2", 0.448, None, None
),
BenchmarkCliTestInput(
"sentence-transformers/all-MiniLM-L6-v2", 0.550, tasks=["DKHate"]
),
BenchmarkCliTestInput(
"sentence-transformers/all-MiniLM-L6-v2", 0.525, tasks=["DKHate", "ScaLA"]
),
BenchmarkCliTestInput(
"sentence-transformers/all-MiniLM-L6-v2",
0.487,
languages=["sv", "no", "nn"],
),
BenchmarkCliTestInput(
"sentence-transformers/all-MiniLM-L6-v2", 0.423, languages=["da"]
),
BenchmarkCliTestInput(
"test_model",
np.nan,
code_path=(test_dir / "benchmark_cli_code_inject.py"),
tasks=["test-encode-task"],
ignore_cache=True,
),
],
)
Expand Down Expand Up @@ -66,7 +83,9 @@ def test_run_benchmark_cli(inputs: BenchmarkCliTestInput, tmp_path: Path):
res = load_results(tmp_path)
assert len(res) == 1
bench_res = res[0]
bench_res.task_results = [tr for tr in bench_res.task_results if tr.task_name != "test-encode-task"]
bench_res.task_results = [
tr for tr in bench_res.task_results if tr.task_name != "test-encode-task"
]
assert is_approximately_equal(bench_res.get_main_score(), inputs.score)


Expand All @@ -77,5 +96,21 @@ def test_run_cli(inputs: BenchmarkCliTestInput, tmp_path: Path):

assert len(res) == 1
bench_res = res[0]
bench_res.task_results = [tr for tr in bench_res.task_results if tr.task_name != "test-encode-task"]
bench_res.task_results = [
tr for tr in bench_res.task_results if tr.task_name != "test-encode-task"
]
assert is_approximately_equal(bench_res.get_main_score(), inputs.score)


def test_run_some_models():
    """Smoke-test several different models on a single small task.

    The benchmark is invoked with ``ignore_cache=True`` so that the models
    are actually run. No score is asserted: the test passes as long as the
    call completes without raising.
    """
    run_benchmark_cli(
        models=[
            "sentence-transformers/all-MiniLM-L6-v2",
            "intfloat/e5-small",
            "translate-e5-small",
            "fasttext-cc-da-300",
        ],
        tasks=["DKHate"],
        ignore_cache=True,
    )

0 comments on commit 61ff3bb

Please sign in to comment.