From 032002a54c74a4ec155396595f100c7e6563a39e Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 16 Dec 2024 10:39:21 -1000 Subject: [PATCH 1/7] SUT arg(s) is now required by CLI --- src/modelbench/run.py | 36 +++++++++++++----------------- tests/modelbench_tests/test_run.py | 31 +++++++++++++------------ 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/modelbench/run.py b/src/modelbench/run.py index c26a7cb8..c697f403 100644 --- a/src/modelbench/run.py +++ b/src/modelbench/run.py @@ -78,7 +78,7 @@ def cli() -> None: @click.option("--max-instances", "-m", type=int, default=100) @click.option("--debug", default=False, is_flag=True) @click.option("--json-logs", default=False, is_flag=True, help="Print only machine-readable progress reports") -@click.option("sut_uids", "--sut", "-s", multiple=True, help="SUT uid(s) to run") +@click.option("sut_uids", "--sut", "-s", multiple=True, help="SUT uid(s) to run", required=True) @click.option("--anonymize", type=int, help="Random number seed for consistent anonymization of SUTs") @click.option("--parallel", default=False, help="Obsolete flag, soon to be removed") @click.option( @@ -189,25 +189,21 @@ def consistency_check(journal_path, verbose): print("\t", j) -def find_suts_for_sut_argument(sut_args: List[str]): - if sut_args: - suts = [] - default_suts_by_key = {s.key: s for s in DEFAULT_SUTS} - registered_sut_keys = set(i[0] for i in SUTS.items()) - for sut_arg in sut_args: - if sut_arg in default_suts_by_key: - suts.append(default_suts_by_key[sut_arg]) - elif sut_arg in registered_sut_keys: - suts.append(ModelGaugeSut.for_key(sut_arg)) - else: - all_sut_keys = registered_sut_keys.union(set(default_suts_by_key.keys())) - raise click.BadParameter( - f"Unknown key '{sut_arg}'. Valid options are {sorted(all_sut_keys, key=lambda x: x.lower())}", - param_hint="sut", - ) - - else: - suts = DEFAULT_SUTS +def find_suts_for_sut_argument(sut_uids: List[str]): + suts = [] + default_suts_by_key = {s.key: s for s in DEFAULT_SUTS} + registered_sut_keys = set(i[0] for i in SUTS.items()) + for sut_uid in sut_uids: + if sut_uid in default_suts_by_key: + suts.append(default_suts_by_key[sut_uid]) + elif sut_uid in registered_sut_keys: + suts.append(ModelGaugeSut.for_key(sut_uid)) + else: + all_sut_keys = registered_sut_keys.union(set(default_suts_by_key.keys())) + raise click.BadParameter( + f"Unknown uid '{sut_uid}'. 
Valid options are {sorted(all_sut_keys, key=lambda x: x.lower())}", + param_hint="sut", + ) return suts diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index 9ce76147..8afa3a7f 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -52,9 +52,6 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: def test_find_suts(): - # nothing gets everything - assert find_suts_for_sut_argument([]) == DEFAULT_SUTS - # key from modelbench gets a known SUT assert find_suts_for_sut_argument(["mistral-7b"]) == [ModelGaugeSut.for_key("mistral-7b")] @@ -75,10 +72,14 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: def uid(self): return "my_benchmark" + @pytest.fixture + def sut_uid(self): + return "demo_yes_no" + def mock_score( self, benchmark=GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, "practice"), - sut=ModelGaugeSut.for_key("mistral-7b"), + sut=ModelGaugeSut.for_key("demo_yes_no"), ): return BenchmarkScore( benchmark, @@ -225,16 +226,18 @@ def test_invalid_benchmark_versions_can_not_be_called(self, version, runner): assert "Invalid value for '--version'" in result.output @pytest.mark.skip(reason="we have temporarily removed other languages") - def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_score_benchmarks): - result = runner.invoke(cli, ["benchmark", "--locale", "fr_FR"]) + def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_score_benchmarks, sut_uid): + result = runner.invoke(cli, ["benchmark", "--locale", "fr_FR", "--sut", sut_uid]) benchmark_arg = mock_score_benchmarks.call_args.args[0][0] assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1) assert benchmark_arg.locale == Locale.FR_FR @pytest.mark.skip(reason="we have temporarily removed other languages") - def test_calls_score_benchmark_all_locales(self, runner, mock_score_benchmarks, tmp_path): - result = runner.invoke(cli, ["benchmark", "--locale", "all", "--output-dir", str(tmp_path.absolute())]) + def test_calls_score_benchmark_all_locales(self, runner, mock_score_benchmarks, sut_uid, tmp_path): + result = runner.invoke( + cli, ["benchmark", "--locale", "all", "--output-dir", str(tmp_path.absolute()), "--sut", sut_uid] + ) benchmark_args = mock_score_benchmarks.call_args.args[0] locales = set([benchmark_arg.locale for benchmark_arg in benchmark_args]) @@ -251,22 +254,22 @@ def test_calls_score_benchmark_all_locales(self, runner, mock_score_benchmarks, # benchmark_arg = mock_score_benchmarks.call_args.args[0][0] # assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark) - def test_v1_en_us_practice_is_default(self, runner, mock_score_benchmarks): - result = runner.invoke(cli, ["benchmark"]) + def test_v1_en_us_practice_is_default(self, runner, mock_score_benchmarks, sut_uid): + result = runner.invoke(cli, ["benchmark", "--sut", sut_uid]) benchmark_arg = mock_score_benchmarks.call_args.args[0][0] assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1) assert benchmark_arg.locale == Locale.EN_US assert benchmark_arg.prompt_set == "practice" - def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner): - result = runner.invoke(cli, ["benchmark", "--prompt-set", "fake"]) + def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid): + result = runner.invoke(cli, ["benchmark", "--prompt-set", "fake", "--sut", sut_uid]) assert result.exit_code == 2 assert "Invalid value for '--prompt-set'" in result.output @pytest.mark.parametrize("prompt_set", 
PROMPT_SETS.keys()) - def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_score_benchmarks, prompt_set): - result = runner.invoke(cli, ["benchmark", "--prompt-set", prompt_set]) + def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_score_benchmarks, prompt_set, sut_uid): + result = runner.invoke(cli, ["benchmark", "--prompt-set", prompt_set, "--sut", sut_uid]) benchmark_arg = mock_score_benchmarks.call_args.args[0][0] assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmarkV1) From 3bf34358473547fdbf0348dfd0c3eda8d2774f14 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 16 Dec 2024 10:51:58 -1000 Subject: [PATCH 2/7] Get rid of DEFFAULT_SUTS --- src/modelbench/run.py | 10 +++------- src/modelbench/suts.py | 10 ---------- tests/modelbench_tests/test_benchmark_grading.py | 11 ++++++++--- tests/modelbench_tests/test_run.py | 2 +- 4 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/modelbench/run.py b/src/modelbench/run.py index c697f403..de638765 100644 --- a/src/modelbench/run.py +++ b/src/modelbench/run.py @@ -26,7 +26,7 @@ from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results from modelbench.hazards import STANDARDS from modelbench.record import dump_json -from modelbench.suts import ModelGaugeSut, SutDescription, DEFAULT_SUTS +from modelbench.suts import ModelGaugeSut, SutDescription from modelgauge.config import load_secrets_from_config, write_default_config from modelgauge.load_plugins import load_plugins from modelgauge.sut_registry import SUTS @@ -191,17 +191,13 @@ def consistency_check(journal_path, verbose): def find_suts_for_sut_argument(sut_uids: List[str]): suts = [] - default_suts_by_key = {s.key: s for s in DEFAULT_SUTS} registered_sut_keys = set(i[0] for i in SUTS.items()) for sut_uid in sut_uids: - if sut_uid in default_suts_by_key: - suts.append(default_suts_by_key[sut_uid]) - elif sut_uid in registered_sut_keys: + if sut_uid in registered_sut_keys: suts.append(ModelGaugeSut.for_key(sut_uid)) else: - all_sut_keys = registered_sut_keys.union(set(default_suts_by_key.keys())) raise click.BadParameter( - f"Unknown uid '{sut_uid}'. Valid options are {sorted(all_sut_keys, key=lambda x: x.lower())}", + f"Unknown uid '{sut_uid}'. Valid options are {sorted(registered_sut_keys, key=lambda x: x.lower())}", param_hint="sut", ) return suts diff --git a/src/modelbench/suts.py b/src/modelbench/suts.py index fb7cec8a..109e0922 100644 --- a/src/modelbench/suts.py +++ b/src/modelbench/suts.py @@ -64,13 +64,3 @@ def _register_required_suts(): _register_required_suts() -# TODO: This was originally SUTS_FOR_V_0_5. Do we still need this? And if so, is this the best place for it? 
-DEFAULT_SUTS = [ - ModelGaugeSut.for_key(k) - for k in [ - "deepseek-67b", - "llama-2-13b-chat", - "mistral-7b", - "mixtral-8x-7b", - ] -] diff --git a/tests/modelbench_tests/test_benchmark_grading.py b/tests/modelbench_tests/test_benchmark_grading.py index f3512748..3e84db03 100644 --- a/tests/modelbench_tests/test_benchmark_grading.py +++ b/tests/modelbench_tests/test_benchmark_grading.py @@ -7,7 +7,7 @@ from modelbench.benchmarks import BenchmarkScore, GeneralPurposeAiChatBenchmarkV1, Locale from modelbench.hazards import HazardScore, SafeHazardV1, Standards from modelbench.scoring import ValueEstimate -from modelbench.suts import DEFAULT_SUTS +from modelbench.suts import ModelGaugeSut # NOTE: the following functions behave like fixtures, but they are not marked as @pytest.fixture @@ -138,6 +138,11 @@ def poor_standards(): return Standards(pathlib.Path(__file__).parent / "data" / "standards_poor.json") +@pytest.fixture +def sut(): + return ModelGaugeSut.for_key("demo_yes_no") + + @pytest.mark.parametrize( "hazard_scores,total_scored,total_safe,frac_safe,standards,expected_grade", [ @@ -172,11 +177,11 @@ def poor_standards(): ], # poor score is poor even if the reference is poor ], ) -def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade): +def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade, sut): benchmark = GeneralPurposeAiChatBenchmarkV1(locale=Locale.EN_US, prompt_set="practice", evaluator="default") score = BenchmarkScore( benchmark_definition=benchmark, - sut=DEFAULT_SUTS[0], + sut=sut, hazard_scores=hazard_scores, end_time=datetime.now(), ) diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index 8afa3a7f..1b21f5bf 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -16,7 +16,7 @@ from modelbench.hazards import HazardScore, HazardDefinition, SafeHazardV1 from modelbench.run import benchmark, cli, find_suts_for_sut_argument, get_benchmark from modelbench.scoring import ValueEstimate -from modelbench.suts import SutDescription, DEFAULT_SUTS, ModelGaugeSut +from modelbench.suts import SutDescription, ModelGaugeSut from modelgauge.base_test import PromptResponseTest from modelgauge.records import TestRecord from modelgauge.secret_values import RawSecrets From 79e5ce4e2899a89382b45a9ae30a7ceb97f76ad5 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 16 Dec 2024 12:12:55 -1000 Subject: [PATCH 3/7] mb tests use centralized SUT fixtures --- tests/modelbench_tests/conftest.py | 28 +++++ tests/modelbench_tests/test_benchmark.py | 5 +- .../test_benchmark_grading.py | 10 +- tests/modelbench_tests/test_record.py | 16 ++- tests/modelbench_tests/test_run.py | 100 ++++++++---------- tests/modelbench_tests/test_run_journal.py | 9 +- 6 files changed, 89 insertions(+), 79 deletions(-) create mode 100644 tests/modelbench_tests/conftest.py diff --git a/tests/modelbench_tests/conftest.py b/tests/modelbench_tests/conftest.py new file mode 100644 index 00000000..e215bfa6 --- /dev/null +++ b/tests/modelbench_tests/conftest.py @@ -0,0 +1,28 @@ +import pytest + +from modelbench.suts import ModelGaugeSut +from modelgauge.sut_registry import SUTS +from modelgauge_tests.fake_sut import FakeSUT + +# Need to declare global here because session start hook can't access fixtures. 
+_SUT_UID = "fake-sut" + + +def pytest_sessionstart(session): + """Register the fake SUT during the session start.""" + SUTS.register(FakeSUT, _SUT_UID) + + +def pytest_sessionfinish(session, exitstatus): + """Remove fake SUTs from registry.""" + del SUTS._lookup[_SUT_UID] + + +@pytest.fixture(scope="session") +def sut_uid(): + return _SUT_UID + + +@pytest.fixture +def wrapped_sut(sut_uid): + return ModelGaugeSut.for_key(sut_uid) diff --git a/tests/modelbench_tests/test_benchmark.py b/tests/modelbench_tests/test_benchmark.py index a66d1b41..a629df86 100644 --- a/tests/modelbench_tests/test_benchmark.py +++ b/tests/modelbench_tests/test_benchmark.py @@ -17,7 +17,6 @@ STANDARDS, ) from modelbench.scoring import ValueEstimate -from modelbench.suts import ModelGaugeSut from modelgauge.base_test import BaseTest from modelgauge.records import TestRecord from modelgauge.secret_values import RawSecrets @@ -80,13 +79,13 @@ class TestBenchmarkScoringBasics: grade_labels = {5: "Excellent", 4: "Very Good", 3: "Good", 2: "Fair", 1: "Poor"} @pytest.fixture - def benchmark_score(self, benchmark, hazard): + def benchmark_score(self, benchmark, hazard, wrapped_sut): def func(probability): ve = ValueEstimate.make(probability, 10) hazard.reference_standard = lambda: 0.8 bs = BenchmarkScore( benchmark, - ModelGaugeSut.for_key("mistral-7b"), + wrapped_sut, [HazardScore(hazard_definition=hazard, score=ve, test_scores={}, exceptions=0)], datetime.fromtimestamp(1700000000), ) diff --git a/tests/modelbench_tests/test_benchmark_grading.py b/tests/modelbench_tests/test_benchmark_grading.py index 3e84db03..96bb4031 100644 --- a/tests/modelbench_tests/test_benchmark_grading.py +++ b/tests/modelbench_tests/test_benchmark_grading.py @@ -7,7 +7,6 @@ from modelbench.benchmarks import BenchmarkScore, GeneralPurposeAiChatBenchmarkV1, Locale from modelbench.hazards import HazardScore, SafeHazardV1, Standards from modelbench.scoring import ValueEstimate -from modelbench.suts import ModelGaugeSut # NOTE: the following functions behave like fixtures, but they are not marked as @pytest.fixture @@ -138,11 +137,6 @@ def poor_standards(): return Standards(pathlib.Path(__file__).parent / "data" / "standards_poor.json") -@pytest.fixture -def sut(): - return ModelGaugeSut.for_key("demo_yes_no") - - @pytest.mark.parametrize( "hazard_scores,total_scored,total_safe,frac_safe,standards,expected_grade", [ @@ -177,11 +171,11 @@ def sut(): ], # poor score is poor even if the reference is poor ], ) -def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade, sut): +def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade, wrapped_sut): benchmark = GeneralPurposeAiChatBenchmarkV1(locale=Locale.EN_US, prompt_set="practice", evaluator="default") score = BenchmarkScore( benchmark_definition=benchmark, - sut=sut, + sut=wrapped_sut, hazard_scores=hazard_scores, end_time=datetime.now(), ) diff --git a/tests/modelbench_tests/test_record.py b/tests/modelbench_tests/test_record.py index 6d47c753..424bc3dd 100644 --- a/tests/modelbench_tests/test_record.py +++ b/tests/modelbench_tests/test_record.py @@ -15,18 +15,17 @@ ) from modelbench.run import FakeSut from modelbench.scoring import ValueEstimate -from modelbench.suts import ModelGaugeSut from modelgauge.record_init import InitializationRecord from modelgauge.tests.safe_v1 import Locale @pytest.fixture() -def benchmark_score(end_time): +def benchmark_score(end_time, wrapped_sut): bd = 
GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, "practice") bs = BenchmarkScore( bd, - ModelGaugeSut.for_key("mistral-7b"), + wrapped_sut, [ HazardScore( hazard_definition=SafeHazardV1("cse", Locale.EN_US, "practice"), @@ -55,14 +54,13 @@ def encode_and_parse(o): return json.loads(s) -def test_sut(): - sut = ModelGaugeSut.for_key("mistral-7b") - assert encode_and_parse(sut) == {"uid": "mistral-7b"} - sut.instance(MagicMock()) - with_initialization = encode_and_parse(sut) +def test_sut(sut_uid, wrapped_sut): + assert encode_and_parse(wrapped_sut) == {"uid": sut_uid} + wrapped_sut.instance(MagicMock()) + with_initialization = encode_and_parse(wrapped_sut) assert "uid" in with_initialization assert "initialization" in with_initialization - assert encode_and_parse(sut) == with_initialization + assert encode_and_parse(wrapped_sut) == with_initialization def test_anonymous_sut(): diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index 1b21f5bf..6e7c6231 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -33,8 +33,7 @@ def score(self, sut_scores: Mapping[str, TestRecord]) -> "HazardScore": ) -def fake_benchmark_run(hazards, tmp_path): - sut = ModelGaugeSut.for_key("mistral-7b") +def fake_benchmark_run(hazards, wrapped_sut, tmp_path): if isinstance(hazards, HazardDefinition): hazards = [hazards] @@ -45,8 +44,8 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: benchmark = ABenchmark() benchmark_run = BenchmarkRun(BenchmarkRunner(tmp_path)) benchmark_run.benchmarks = [benchmark] - benchmark_run.benchmark_scores[benchmark][sut] = BenchmarkScore( - benchmark, sut, [h.score({}) for h in hazards], None + benchmark_run.benchmark_scores[benchmark][wrapped_sut] = BenchmarkScore( + benchmark, wrapped_sut, [h.score({}) for h in hazards], None ) return benchmark_run @@ -72,14 +71,10 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: def uid(self): return "my_benchmark" - @pytest.fixture - def sut_uid(self): - return "demo_yes_no" - def mock_score( self, + sut: ModelGaugeSut, benchmark=GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, "practice"), - sut=ModelGaugeSut.for_key("demo_yes_no"), ): return BenchmarkScore( benchmark, @@ -96,10 +91,10 @@ def mock_score( ) @pytest.fixture(autouse=False) - def mock_score_benchmarks(self, monkeypatch): + def mock_score_benchmarks(self, wrapped_sut, monkeypatch): import modelbench - mock = MagicMock(return_value=[self.mock_score()]) + mock = MagicMock(return_value=[self.mock_score(wrapped_sut)]) monkeypatch.setattr(modelbench.run, "score_benchmarks", mock) return mock @@ -125,7 +120,7 @@ def runner(self): # "version,locale", [("0.5", None), ("1.0", "en_US"), ("1.0", "fr_FR"), ("1.0", "hi_IN"), ("1.0", "zh_CN")] ) def test_benchmark_basic_run_produces_json( - self, runner, mock_score_benchmarks, version, locale, prompt_set, tmp_path + self, runner, mock_score_benchmarks, sut_uid, version, locale, prompt_set, tmp_path ): benchmark_options = ["--version", version] if locale is not None: @@ -135,25 +130,23 @@ def test_benchmark_basic_run_produces_json( benchmark = get_benchmark( version, locale if locale else Locale.EN_US, prompt_set if prompt_set else "practice", "default" ) - with unittest.mock.patch("modelbench.run.find_suts_for_sut_argument") as mock_find_suts: - mock_find_suts.return_value = [SutDescription("fake")] - command_options = [ - "benchmark", - "-m", - "1", - "--sut", - "fake", - "--output-dir", - str(tmp_path.absolute()), - *benchmark_options, - ] - result = runner.invoke( - 
cli, - command_options, - catch_exceptions=False, - ) - assert result.exit_code == 0 - assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists + command_options = [ + "benchmark", + "-m", + "1", + "--sut", + sut_uid, + "--output-dir", + str(tmp_path.absolute()), + *benchmark_options, + ] + result = runner.invoke( + cli, + command_options, + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists @pytest.mark.parametrize( "version,locale,prompt_set", @@ -173,8 +166,9 @@ def test_benchmark_multiple_suts_produces_json(self, runner, version, locale, pr version, locale if locale else Locale.EN_US, prompt_set if prompt_set else "practice", "default" ) - mock = MagicMock(return_value=[self.mock_score(benchmark, "fake-2"), self.mock_score(benchmark, "fake-2")]) + mock = MagicMock(return_value=[self.mock_score("fake-2", benchmark), self.mock_score("fake-2", benchmark)]) monkeypatch.setattr(modelbench.run, "score_benchmarks", mock) + # TODO: There is a bug here that always makes it pass. with unittest.mock.patch("modelbench.run.find_suts_for_sut_argument") as mock_find_suts: mock_find_suts.return_value = [SutDescription("fake-1"), SutDescription("fake-2")] result = runner.invoke( @@ -196,28 +190,26 @@ def test_benchmark_multiple_suts_produces_json(self, runner, version, locale, pr assert result.exit_code == 0 assert (tmp_path / f"benchmark_record-{benchmark.uid}.json").exists - def test_benchmark_anonymous_run_produces_json(self, runner, tmp_path, mock_score_benchmarks): - with unittest.mock.patch("modelbench.run.find_suts_for_sut_argument") as mock_find_suts: - mock_find_suts.return_value = [SutDescription("fake")] - result = runner.invoke( - cli, - [ - "benchmark", - "--anonymize", - "42", - "-m", - "1", - "--sut", - "fake", - "--output-dir", - str(tmp_path.absolute()), - ], - catch_exceptions=False, - ) - assert result.exit_code == 0, result.stdout - assert ( - tmp_path / f"benchmark_record-{GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, 'practice').uid}.json" - ).exists + def test_benchmark_anonymous_run_produces_json(self, runner, sut_uid, tmp_path, mock_score_benchmarks): + result = runner.invoke( + cli, + [ + "benchmark", + "--anonymize", + "42", + "-m", + "1", + "--sut", + sut_uid, + "--output-dir", + str(tmp_path.absolute()), + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + assert ( + tmp_path / f"benchmark_record-{GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, 'practice').uid}.json" + ).exists @pytest.mark.parametrize("version", ["0.0", "0.5"]) def test_invalid_benchmark_versions_can_not_be_called(self, version, runner): diff --git a/tests/modelbench_tests/test_run_journal.py b/tests/modelbench_tests/test_run_journal.py index 35b1cc28..be2a57a1 100644 --- a/tests/modelbench_tests/test_run_journal.py +++ b/tests/modelbench_tests/test_run_journal.py @@ -13,7 +13,6 @@ from modelbench.benchmark_runner_items import Timer from modelbench.run_journal import RunJournal, for_journal -from modelbench.suts import ModelGaugeSut from modelgauge.sut import SUTResponse, SUTCompletion, TopTokens, TokenProbability from modelgauge.tests.safe_v1 import Locale @@ -202,9 +201,9 @@ def test_run_item_output(self, journal): assert e["test"] == "a_test" assert e["prompt_id"] == "id1" - def test_run_item_output_with_sut(self, journal): + def test_run_item_output_with_sut(self, journal, wrapped_sut): tri = self.make_test_run_item("id1", "a_test", "Hello?") - tri.sut = ModelGaugeSut("demo_yes_no") + 
tri.sut = wrapped_sut journal.item_entry("an item", tri) @@ -220,9 +219,9 @@ def test_run_item_output_with_extra_args(self, journal): assert e["one"] == 1 assert e["two"] == 2 - def test_item_exception_entry(self, journal): + def test_item_exception_entry(self, journal, wrapped_sut): tri = self.make_test_run_item("id1", "a_test", "Hello?") - tri.sut = ModelGaugeSut("demo_yes_no") + tri.sut = wrapped_sut journal.item_exception_entry("fail", tri, ValueError()) From 4609f7352cdb6019b22809f16d72cb16d6062450 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 16 Dec 2024 12:17:32 -1000 Subject: [PATCH 4/7] mv conftest up to root tests dir --- tests/{modelbench_tests => }/conftest.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{modelbench_tests => }/conftest.py (100%) diff --git a/tests/modelbench_tests/conftest.py b/tests/conftest.py similarity index 100% rename from tests/modelbench_tests/conftest.py rename to tests/conftest.py From ae876544623b53034d656bd88ac3b3cf594523f0 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 16 Dec 2024 12:33:08 -1000 Subject: [PATCH 5/7] Modelbench does not register SUTs --- src/modelbench/suts.py | 24 ------------------------ tests/modelbench_tests/test_run.py | 8 ++------ 2 files changed, 2 insertions(+), 30 deletions(-) diff --git a/src/modelbench/suts.py b/src/modelbench/suts.py index 109e0922..4b11221a 100644 --- a/src/modelbench/suts.py +++ b/src/modelbench/suts.py @@ -40,27 +40,3 @@ def instance_initialization(self): instance = self.instance(None) if instance: return instance.initialization_record - - -def _register_required_suts(): - suts_to_register = { - "deepseek-67b": (TogetherChatSUT, "deepseek-ai/deepseek-llm-67b-chat"), - "gemma-7b": (TogetherChatSUT, "google/gemma-7b-it"), - "mistral-7b": (TogetherChatSUT, "mistralai/Mistral-7B-Instruct-v0.2"), - "mixtral-8x-7b": (TogetherChatSUT, "mistralai/Mixtral-8x7B-Instruct-v0.1"), - "openchat-3_5": (TogetherChatSUT, "openchat/openchat-3.5-1210"), - "stripedhyena-nous-7b": (TogetherChatSUT, "togethercomputer/StripedHyena-Nous-7B"), - "vicuna-13b": (TogetherChatSUT, "lmsys/vicuna-13b-v1.5"), - "wizardlm-13b": (TogetherChatSUT, "WizardLM/WizardLM-13B-V1.2"), - } - - required_secrets = { - TogetherCompletionsSUT: (InjectSecret(TogetherApiKey),), - TogetherChatSUT: (InjectSecret(TogetherApiKey),), - } - - for key, details in suts_to_register.items(): - SUTS.register(details[0], key, details[1], *required_secrets[details[0]]) - - -_register_required_suts() diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index 6e7c6231..f10376ed 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -50,13 +50,9 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: return benchmark_run -def test_find_suts(): +def test_find_suts(sut_uid): # key from modelbench gets a known SUT - assert find_suts_for_sut_argument(["mistral-7b"]) == [ModelGaugeSut.for_key("mistral-7b")] - - # key from modelgauge gets a dynamic one - dynamic_qwen = find_suts_for_sut_argument(["llama-3-70b-chat-hf"])[0] - assert dynamic_qwen.key == "llama-3-70b-chat-hf" + assert find_suts_for_sut_argument([sut_uid]) == [ModelGaugeSut.for_key(sut_uid)] with pytest.raises(click.BadParameter): find_suts_for_sut_argument(["something nonexistent"]) From 920a9619db1ede724e2e113e11bd703e617146b0 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 16 Dec 2024 13:43:27 -1000 Subject: [PATCH 6/7] print known SUT uids on newlines --- src/modelbench/run.py | 4 
+++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/modelbench/run.py b/src/modelbench/run.py index de638765..6e7fe433 100644 --- a/src/modelbench/run.py +++ b/src/modelbench/run.py @@ -196,8 +196,10 @@ def find_suts_for_sut_argument(sut_uids: List[str]): if sut_uid in registered_sut_keys: suts.append(ModelGaugeSut.for_key(sut_uid)) else: + valid_suts = sorted(registered_sut_keys, key=lambda x: x.lower()) + valid_suts_str = "\n\t".join(valid_suts) raise click.BadParameter( - f"Unknown uid '{sut_uid}'. Valid options are {sorted(registered_sut_keys, key=lambda x: x.lower())}", + f"Unknown uid '{sut_uid}'. Valid options are: {valid_suts_str}", param_hint="sut", ) return suts From fca6120631235020e401247e1a26e773610b9350 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Wed, 18 Dec 2024 15:48:40 -0500 Subject: [PATCH 7/7] Remove SUT wrapper (#758) --- src/modelbench/benchmark_runner.py | 50 ++++++++------- src/modelbench/benchmark_runner_items.py | 5 +- src/modelbench/benchmarks.py | 6 +- src/modelbench/record.py | 8 +-- src/modelbench/run.py | 55 +++++++++-------- src/modelbench/suts.py | 42 ------------- tests/conftest.py | 5 +- tests/modelbench_tests/test_benchmark.py | 5 +- .../test_benchmark_grading.py | 4 +- .../modelbench_tests/test_benchmark_runner.py | 61 ++++++++----------- tests/modelbench_tests/test_record.py | 28 +++------ tests/modelbench_tests/test_run.py | 23 ++++--- tests/modelbench_tests/test_run_journal.py | 8 +-- 13 files changed, 120 insertions(+), 180 deletions(-) delete mode 100644 src/modelbench/suts.py diff --git a/src/modelbench/benchmark_runner.py b/src/modelbench/benchmark_runner.py index 7e178655..029f1390 100644 --- a/src/modelbench/benchmark_runner.py +++ b/src/modelbench/benchmark_runner.py @@ -18,7 +18,6 @@ from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore from modelbench.cache import DiskCache, MBCache from modelbench.run_journal import RunJournal -from modelbench.suts import ModelGaugeSut from modelgauge.annotator import CompletionAnnotator from modelgauge.annotator_registry import ANNOTATORS from modelgauge.base_test import PromptResponseTest, TestResult @@ -27,7 +26,7 @@ from modelgauge.prompt import TextPrompt from modelgauge.records import TestRecord from modelgauge.single_turn_prompt_response import PromptWithContext, TestItem -from modelgauge.sut import SUTCompletion, SUTResponse +from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse logger = logging.getLogger(__name__) @@ -144,12 +143,12 @@ def _add_test_annotators(self, test: PromptResponseTest): annotators.append(ANNOTATORS.make_instance(annotator_uid, secrets=self.secrets)) self.test_annotators[test.uid] = annotators - def add_finished_item(self, item: "TestRunItem"): + def add_finished_item(self, item: TestRunItem): if item.completion() and item.annotations and not item.exceptions: - self.finished_items[item.sut.key][item.test.uid].append(item) + self.finished_items[item.sut.uid][item.test.uid].append(item) self.journal.item_entry("item finished", item) else: - self.failed_items[item.sut.key][item.test.uid].append(item) + self.failed_items[item.sut.uid][item.test.uid].append(item) self.journal.item_entry( "item failed", item, @@ -164,10 +163,10 @@ def add_test_record(self, test_record: TestRecord): self.test_records[test_record.test_uid][test_record.sut_uid] = test_record def finished_items_for(self, sut, test) -> Sequence[TestItem]: - return self.finished_items[sut.key][test.uid] + return self.finished_items[sut.uid][test.uid] def 
failed_items_for(self, sut, test) -> Sequence[TestItem]: - return self.failed_items[sut.key][test.uid] + return self.failed_items[sut.uid][test.uid] def annotators_for_test(self, test: PromptResponseTest) -> Sequence[CompletionAnnotator]: return self.test_annotators[test.uid] @@ -202,7 +201,7 @@ def __init__(self, runner: "TestRunner"): class BenchmarkRun(TestRunBase): - benchmark_scores: dict[BenchmarkDefinition, dict[ModelGaugeSut, BenchmarkScore]] + benchmark_scores: dict[BenchmarkDefinition, dict[PromptResponseTest, BenchmarkScore]] benchmarks: Sequence[BenchmarkDefinition] def __init__(self, runner: "BenchmarkRunner"): @@ -283,8 +282,8 @@ def __init__(self, test_run: TestRunBase, cache: MBCache, thread_count=1): self.test_run = test_run def handle_item(self, item: TestRunItem): - mg_sut = item.sut.instance(self.test_run.secrets) - raw_request = mg_sut.translate_text_prompt(item.prompt_with_context().prompt) + sut = item.sut + raw_request = sut.translate_text_prompt(item.prompt_with_context().prompt) cache_key = raw_request.model_dump_json(exclude_none=True) self._debug(f"looking for {cache_key} in cache") try: @@ -297,16 +296,16 @@ def handle_item(self, item: TestRunItem): self._debug(f"cache entry not found; processing and saving") with Timer() as timer: try: - raw_response = mg_sut.evaluate(raw_request) + raw_response = sut.evaluate(raw_request) except Exception as e: - logger.error(f"failure fetching sut {mg_sut.uid} on first try: {raw_request}", exc_info=True) - raw_response = mg_sut.evaluate(raw_request) + logger.error(f"failure fetching sut {sut.uid} on first try: {raw_request}", exc_info=True) + raw_response = sut.evaluate(raw_request) self.cache[cache_key] = raw_response self.test_run.journal.item_entry( "fetched sut response", item, run_time=timer, request=raw_request, response=raw_response ) - response = mg_sut.translate_response(raw_request, raw_response) + response = sut.translate_response(raw_request, raw_response) item.sut_response = response self.test_run.journal.item_entry("translated sut response", item, response=response) @@ -418,7 +417,7 @@ def __init__(self, data_dir: pathlib.Path): self.thread_count = 1 self.run_tracker = NullRunTracker() - def add_sut(self, sut: ModelGaugeSut): + def add_sut(self, sut: PromptResponseSUT): self.suts.append(sut) def _check_ready_to_run(self): @@ -433,16 +432,15 @@ def _check_external_services(self, run: TestRunBase): self._check_annotators_working(run) def _check_suts_working(self, run: TestRunBase): - def check_sut(sut: ModelGaugeSut): + def check_sut(sut: PromptResponseSUT): try: - mg_sut = sut.instance(self.secrets) - raw_request = mg_sut.translate_text_prompt(TextPrompt(text="Why did the chicken cross the road?")) - raw_response = mg_sut.evaluate(raw_request) - response: SUTResponse = mg_sut.translate_response(raw_request, raw_response) + raw_request = sut.translate_text_prompt(TextPrompt(text="Why did the chicken cross the road?")) + raw_response = sut.evaluate(raw_request) + response: SUTResponse = sut.translate_response(raw_request, raw_response) return bool(response.completions) except Exception as e: - logger.error(f"initial check failure for {sut}", exc_info=e) - print(f"initial check failure for {sut}") + logger.error(f"initial check failure for {sut.uid}", exc_info=e) + print(f"initial check failure for {sut.uid}") traceback.print_exc() return False @@ -497,8 +495,8 @@ def _make_test_record(self, run, sut, test, test_result): test_uid=test.uid, test_initialization=test.initialization_record, 
dependency_versions=test.dependency_helper.versions_used(), - sut_uid=sut._instance.uid, - sut_initialization=sut._instance.initialization_record, + sut_uid=sut.uid, + sut_initialization=sut.initialization_record, test_item_records=[], test_item_exceptions=[], result=TestResult.from_instance(test_result), @@ -628,10 +626,10 @@ def _calculate_benchmark_scores(self, benchmark_run): test_records = {} for test in hazard.tests(benchmark_run.secrets): records = benchmark_run.test_records[test.uid][sut.uid] - assert records, f"No records found for {benchmark_definition} {sut} {hazard} {test.uid}" + assert records, f"No records found for {benchmark_definition} {sut.uid} {hazard} {test.uid}" test_records[test.uid] = records - assert test_records, f"No records found for {benchmark_definition} {sut} {hazard}" + assert test_records, f"No records found for {benchmark_definition} {sut.uid} {hazard}" hazard_score = hazard.score(test_records) hazard_scores.append(hazard_score) # TODO: score needs way less diff --git a/src/modelbench/benchmark_runner_items.py b/src/modelbench/benchmark_runner_items.py index 4d90f36f..d322947e 100644 --- a/src/modelbench/benchmark_runner_items.py +++ b/src/modelbench/benchmark_runner_items.py @@ -5,7 +5,6 @@ from datetime import datetime, timezone from typing import Mapping, List -from modelbench.suts import ModelGaugeSut from modelgauge.annotation import Annotation from modelgauge.annotator import CompletionAnnotator from modelgauge.base_test import PromptResponseTest @@ -20,7 +19,7 @@ SUTResponseAnnotations, SUTCompletionAnnotations, ) -from modelgauge.sut import SUTResponse, SUTCompletion +from modelgauge.sut import PromptResponseSUT, SUTResponse, SUTCompletion # in their own file to solve circular import problems @@ -100,7 +99,7 @@ class TestRunItem: test: ModelgaugeTestWrapper test_item: TestItem - sut: ModelGaugeSut = None + sut: PromptResponseSUT = None sut_response: SUTResponse = None annotations: dict[str, Annotation] = dataclasses.field(default_factory=dict) measurements: dict[str, float] = dataclasses.field(default_factory=dict) diff --git a/src/modelbench/benchmarks.py b/src/modelbench/benchmarks.py index 18c44296..a631149e 100644 --- a/src/modelbench/benchmarks.py +++ b/src/modelbench/benchmarks.py @@ -5,11 +5,11 @@ from typing import List, Sequence import casefy +from modelgauge.sut import PromptResponseSUT from modelgauge.tests.safe_v1 import Locale from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade, score_to_ordinal_grade_v1 -from modelbench.suts import ModelGaugeSut from modelbench.uid import HasUid @@ -17,7 +17,7 @@ class BenchmarkScore(ABC, LetterGradeMixin): def __init__( self, benchmark_definition: "BenchmarkDefinition", - sut: ModelGaugeSut, + sut: PromptResponseSUT, hazard_scores: List["HazardScore"], end_time: datetime, ): @@ -82,7 +82,7 @@ def __repr__(self): + "(" + str(self.benchmark_definition) + ", " - + str(self.sut) + + str(self.sut.uid) + ", " + str(self.hazard_scores) + ")" diff --git a/src/modelbench/record.py b/src/modelbench/record.py index 6cba0107..300b11db 100644 --- a/src/modelbench/record.py +++ b/src/modelbench/record.py @@ -8,11 +8,11 @@ import pydantic from modelgauge.base_test import BaseTest +from modelgauge.sut import SUT from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore from modelbench.hazards import HazardDefinition, HazardScore from modelbench.static_content import StaticContent -from 
modelbench.suts import ModelGaugeSut, SutDescription def run_command(*args): @@ -111,10 +111,8 @@ def default(self, o): return result elif isinstance(o, BaseTest): return o.uid - elif isinstance(o, SutDescription): - result = {"uid": o.key} - if isinstance(o, ModelGaugeSut) and o.instance_initialization(): - result["initialization"] = o.instance_initialization() + elif isinstance(o, SUT): + result = {"uid": o.uid, "initialization": o.initialization_record} return result elif isinstance(o, pydantic.BaseModel): return o.model_dump() diff --git a/src/modelbench/run.py b/src/modelbench/run.py index 6e7fe433..7efd36bb 100644 --- a/src/modelbench/run.py +++ b/src/modelbench/run.py @@ -26,9 +26,10 @@ from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results from modelbench.hazards import STANDARDS from modelbench.record import dump_json -from modelbench.suts import ModelGaugeSut, SutDescription -from modelgauge.config import load_secrets_from_config, write_default_config +from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config from modelgauge.load_plugins import load_plugins +from modelgauge.sut import SUT +from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS from modelgauge.tests.safe_v1 import PROMPT_SETS, Locale @@ -190,18 +191,27 @@ def consistency_check(journal_path, verbose): def find_suts_for_sut_argument(sut_uids: List[str]): + # TODO: Put object initialization code in once place shared with modelgauge. + # Make sure we have all the secrets we need. + secrets = load_secrets_from_config() + missing_secrets = [] + unknown_uids = [] suts = [] - registered_sut_keys = set(i[0] for i in SUTS.items()) for sut_uid in sut_uids: - if sut_uid in registered_sut_keys: - suts.append(ModelGaugeSut.for_key(sut_uid)) - else: - valid_suts = sorted(registered_sut_keys, key=lambda x: x.lower()) - valid_suts_str = "\n\t".join(valid_suts) - raise click.BadParameter( - f"Unknown uid '{sut_uid}'. 
Valid options are: {valid_suts_str}", - param_hint="sut", - ) + try: + missing_secrets.extend(SUTS.get_missing_dependencies(sut_uid, secrets=secrets)) + suts.append(SUTS.make_instance(sut_uid, secrets=secrets)) + except KeyError: + unknown_uids.append(sut_uid) + if len(unknown_uids) > 0: + valid_suts = sorted(SUTS.keys(), key=lambda x: x.lower()) + valid_suts_str = "\n\t".join(valid_suts) + raise click.BadParameter( + f"Unknown uids '{unknown_uids}'.\nValid options are: {valid_suts_str}", + param_hint="sut", + ) + raise_if_missing_from_config(missing_secrets) + return suts @@ -258,10 +268,9 @@ def run_benchmarks_for_suts(benchmarks, suts, max_instances, debug=False, json_l return run -class FakeSut(SutDescription): - @property - def name(self): - return self.key.upper() +@modelgauge_sut(capabilities=[]) +class AnonSUT(SUT): + pass def print_summary(benchmark, benchmark_scores, anonymize): @@ -272,10 +281,8 @@ def print_summary(benchmark, benchmark_scores, anonymize): counter = 0 for bs in benchmark_scores: counter += 1 - key = f"sut{counter:02d}" - name = f"System Under Test {counter}" - - bs.sut = FakeSut(key, name) + uid = f"sut{counter:02d}" + bs.sut = AnonSUT(uid) echo(termcolor.colored(f"\nBenchmarking complete for {benchmark.uid}.", "green")) console = Console() @@ -330,10 +337,8 @@ def calibrate(update: bool, file) -> None: def update_standards_to(standards_file): - reference_suts = [ - ModelGaugeSut.for_key("gemma-2-9b-it-hf"), - ModelGaugeSut.for_key("llama-3.1-8b-instruct-turbo-together"), - ] + reference_sut_uids = ["gemma-2-9b-it-hf", "llama-3.1-8b-instruct-turbo-together"] + reference_suts = find_suts_for_sut_argument(reference_sut_uids) if not ensure_ensemble_annotators_loaded(): print("Can't load private annotators needed for calibration") exit(1) @@ -364,7 +369,7 @@ def update_standards_to(standards_file): }, }, "standards": { - "reference_suts": [sut.key for sut in reference_suts], + "reference_suts": [sut.uid for sut in reference_suts], "reference_standards": reference_standards, }, } diff --git a/src/modelbench/suts.py b/src/modelbench/suts.py deleted file mode 100644 index 4b11221a..00000000 --- a/src/modelbench/suts.py +++ /dev/null @@ -1,42 +0,0 @@ -import dataclasses -import functools - -from modelgauge.secret_values import InjectSecret -from modelgauge.sut_registry import SUTS -from modelgauge.suts.together_client import TogetherApiKey, TogetherCompletionsSUT, TogetherChatSUT - - -@dataclasses.dataclass -class SutDescription: - key: str - - @property - def uid(self): - return self.key - - -@dataclasses.dataclass -class ModelGaugeSut(SutDescription): - @classmethod - @functools.cache - def for_key(cls, key: str) -> "ModelGaugeSut": - valid_keys = [item[0] for item in SUTS.items()] - if key in valid_keys: - return ModelGaugeSut(key) - else: - raise ValueError(f"Unknown SUT {key}; valid keys are {valid_keys}") - - def __hash__(self): - return self.key.__hash__() - - def instance(self, secrets): - if not hasattr(self, "_instance"): - if secrets is None: - return None - self._instance = SUTS.make_instance(self.key, secrets=secrets) - return self._instance - - def instance_initialization(self): - instance = self.instance(None) - if instance: - return instance.initialization_record diff --git a/tests/conftest.py b/tests/conftest.py index e215bfa6..720fd7ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import pytest -from modelbench.suts import ModelGaugeSut from modelgauge.sut_registry import SUTS from modelgauge_tests.fake_sut import FakeSUT @@ 
-24,5 +23,5 @@ def sut_uid(): @pytest.fixture -def wrapped_sut(sut_uid): - return ModelGaugeSut.for_key(sut_uid) +def sut(sut_uid): + return FakeSUT(sut_uid) diff --git a/tests/modelbench_tests/test_benchmark.py b/tests/modelbench_tests/test_benchmark.py index a629df86..81a7c3ce 100644 --- a/tests/modelbench_tests/test_benchmark.py +++ b/tests/modelbench_tests/test_benchmark.py @@ -1,4 +1,3 @@ -import pathlib from datetime import datetime from typing import List, Mapping from unittest.mock import MagicMock @@ -79,13 +78,13 @@ class TestBenchmarkScoringBasics: grade_labels = {5: "Excellent", 4: "Very Good", 3: "Good", 2: "Fair", 1: "Poor"} @pytest.fixture - def benchmark_score(self, benchmark, hazard, wrapped_sut): + def benchmark_score(self, benchmark, hazard, sut): def func(probability): ve = ValueEstimate.make(probability, 10) hazard.reference_standard = lambda: 0.8 bs = BenchmarkScore( benchmark, - wrapped_sut, + sut, [HazardScore(hazard_definition=hazard, score=ve, test_scores={}, exceptions=0)], datetime.fromtimestamp(1700000000), ) diff --git a/tests/modelbench_tests/test_benchmark_grading.py b/tests/modelbench_tests/test_benchmark_grading.py index 96bb4031..872ed4be 100644 --- a/tests/modelbench_tests/test_benchmark_grading.py +++ b/tests/modelbench_tests/test_benchmark_grading.py @@ -171,11 +171,11 @@ def poor_standards(): ], # poor score is poor even if the reference is poor ], ) -def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade, wrapped_sut): +def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade, sut): benchmark = GeneralPurposeAiChatBenchmarkV1(locale=Locale.EN_US, prompt_set="practice", evaluator="default") score = BenchmarkScore( benchmark_definition=benchmark, - sut=wrapped_sut, + sut=sut, hazard_scores=hazard_scores, end_time=datetime.now(), ) diff --git a/tests/modelbench_tests/test_benchmark_runner.py b/tests/modelbench_tests/test_benchmark_runner.py index 22335a61..7361e42a 100644 --- a/tests/modelbench_tests/test_benchmark_runner.py +++ b/tests/modelbench_tests/test_benchmark_runner.py @@ -8,7 +8,6 @@ from modelbench.cache import InMemoryCache from modelbench.hazards import HazardDefinition, HazardScore from modelbench.scoring import ValueEstimate -from modelbench.suts import ModelGaugeSut from modelgauge.annotators.demo_annotator import DemoYBadAnnotation, DemoYBadResponse from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation from modelgauge.dependency_helper import DependencyHelper @@ -19,11 +18,13 @@ from modelgauge.secret_values import get_all_secrets, RawSecrets from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItemAnnotations from modelgauge.sut import SUTCompletion, SUTResponse +from modelgauge.sut_registry import SUTS from modelgauge.suts.demo_01_yes_no_sut import DemoYesNoResponse from modelgauge.suts.together_client import TogetherChatRequest, TogetherChatResponse from modelgauge_tests.fake_annotator import FakeAnnotator from modelbench_tests.test_run_journal import FakeJournal, reader_for +from modelgauge_tests.fake_sut import FakeSUT # fix pytest autodiscovery issue; see https://github.com/pytest-dev/pytest/issues/12749 for a_class in [i[1] for i in (globals().items()) if inspect.isclass(i[1])]: @@ -121,10 +122,6 @@ def teardown_class(cls): del ANNOTATORS._lookup[uid] cls._original_registered_annotators = None - @pytest.fixture(scope="class", autouse=True) - def load_plugins(self): - 
load_plugins() - def a_run(self, tmp_path, **kwargs) -> BenchmarkRun: runner = BenchmarkRunner(tmp_path / "run") for key, value in kwargs.items(): @@ -160,14 +157,13 @@ def a_wrapped_test(self, a_test, tmp_path): @pytest.fixture() def a_sut(self): - return ModelGaugeSut("demo_yes_no") + return SUTS.make_instance("demo_yes_no", secrets=fake_all_secrets()) @pytest.fixture() - def exploding_sut(self, a_sut): + def exploding_sut(self): real_sut = MagicMock() real_sut.evaluate.side_effect = ValueError("sut done broke") - a_sut.instance = lambda _: real_sut - return a_sut + return real_sut @pytest.fixture() def sut_response(self): @@ -239,8 +235,8 @@ def test_benchmark_source(self, fake_secrets, tmp_path, benchmark): next(iterator) def test_benchmark_sut_assigner(self, a_wrapped_test, tmp_path): - sut_one = ModelGaugeSut("one") - sut_two = ModelGaugeSut("two") + sut_one = FakeSUT("one") + sut_two = FakeSUT("two") test_item = self.make_test_item() bsa = TestRunSutAssigner(self.a_run(tmp_path, suts=[sut_one, sut_two])) @@ -342,32 +338,30 @@ def test_benchmark_results_collector_handles_failed(self, a_sut, tmp_path, a_wra assert run.finished_items_for(a_sut, a_wrapped_test) == [] assert run.failed_items_for(a_sut, a_wrapped_test) == [item] - def test_basic_test_run(self, tmp_path, fake_secrets, a_test): + def test_basic_test_run(self, tmp_path, fake_secrets, a_test, a_sut): runner = TestRunner(tmp_path) runner.secrets = fake_secrets runner.add_test(a_test) - sut = ModelGaugeSut("demo_yes_no") - runner.add_sut(sut) + runner.add_sut(a_sut) runner.max_items = 1 run_result = runner.run() assert run_result.test_records - assert run_result.test_records[a_test.uid][sut.key] + assert run_result.test_records[a_test.uid][a_sut.uid] - def test_basic_benchmark_run(self, tmp_path, fake_secrets, benchmark): + def test_basic_benchmark_run(self, tmp_path, a_sut, fake_secrets, benchmark): runner = BenchmarkRunner(tmp_path) runner.secrets = fake_secrets runner.add_benchmark(benchmark) - sut = ModelGaugeSut("demo_yes_no") - runner.add_sut(sut) + runner.add_sut(a_sut) runner.max_items = 1 run_result = runner.run() assert run_result.benchmark_scores - assert run_result.benchmark_scores[benchmark][sut] + assert run_result.benchmark_scores[benchmark][a_sut] - def test_test_runner_has_standards(self, tmp_path, a_test, fake_secrets): + def test_test_runner_has_standards(self, tmp_path, a_sut, a_test, fake_secrets): runner = TestRunner(tmp_path) with pytest.raises(ValueError) as e: @@ -379,7 +373,7 @@ def test_test_runner_has_standards(self, tmp_path, a_test, fake_secrets): runner.run() assert "add_sut" in str(e) - runner.add_sut(ModelGaugeSut("demo_yes_no")) + runner.add_sut(a_sut) with pytest.raises(ValueError) as e: runner.run() assert "add_test" in str(e) @@ -387,10 +381,10 @@ def test_test_runner_has_standards(self, tmp_path, a_test, fake_secrets): runner.add_test(a_test) runner.run() - def test_benchmark_runner_has_standards(self, tmp_path, benchmark, fake_secrets): + def test_benchmark_runner_has_standards(self, tmp_path, a_sut, benchmark, fake_secrets): runner = BenchmarkRunner(tmp_path) runner.secrets = fake_secrets - runner.add_sut(ModelGaugeSut("demo_yes_no")) + runner.add_sut(a_sut) with pytest.raises(ValueError) as e: runner.run() @@ -400,9 +394,10 @@ def test_benchmark_runner_has_standards(self, tmp_path, benchmark, fake_secrets) runner.run() def test_sut_caching(self, item_from_test, a_wrapped_test, tmp_path): - sut = MagicMock(spec=ModelGaugeSut) - sut.instance().translate_text_prompt.return_value = 
TogetherChatRequest(model="foo", messages=[]) - sut.instance().evaluate.return_value = TogetherChatResponse( + sut = MagicMock(spec=PromptResponseSUT) + sut.uid = "magic-sut" + sut.translate_text_prompt.return_value = TogetherChatRequest(model="foo", messages=[]) + sut.evaluate.return_value = TogetherChatResponse( id="foo", choices=[], usage=TogetherChatResponse.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0), @@ -414,10 +409,10 @@ def test_sut_caching(self, item_from_test, a_wrapped_test, tmp_path): bsw = TestRunSutWorker(run, DiskCache(tmp_path)) bsw.handle_item(TestRunItem(a_wrapped_test, item_from_test, sut)) - assert sut.instance().evaluate.call_count == 1 + assert sut.evaluate.call_count == 1 bsw.handle_item(TestRunItem(a_wrapped_test, item_from_test, sut)) - assert sut.instance().evaluate.call_count == 1 + assert sut.evaluate.call_count == 1 class TestRunJournaling(RunnerTestBase): @@ -435,10 +430,9 @@ def test_item_source(self, fake_secrets, tmp_path, benchmark): entry = run.journal.last_entry() assert entry["message"] == "using test items" - def test_benchmark_sut_assigner(self, a_wrapped_test, tmp_path): - sut_one = ModelGaugeSut("one") + def test_benchmark_sut_assigner(self, a_sut, a_wrapped_test, tmp_path): test_item = self.make_test_item("What's your name?", "id123") - run = self.a_run(tmp_path, suts=[sut_one]) + run = self.a_run(tmp_path, suts=[a_sut]) bsa = TestRunSutAssigner(run) bsa.handle_item(TestRunItem(a_wrapped_test, test_item)) @@ -559,13 +553,12 @@ def test_benchmark_annotation_worker_throws_exception( assert measurement_entry["measurements"] == {} capsys.readouterr() # supress the exception output; can remove when we add proper logging - def test_basic_benchmark_run(self, tmp_path, fake_secrets, benchmark): + def test_basic_benchmark_run(self, tmp_path, a_sut, fake_secrets, benchmark): runner = BenchmarkRunner(tmp_path) runner.secrets = fake_secrets runner.add_benchmark(benchmark) - sut = ModelGaugeSut("demo_yes_no") - runner.add_sut(sut) + runner.add_sut(a_sut) runner.max_items = 1 runner.run() entries = [] diff --git a/tests/modelbench_tests/test_record.py b/tests/modelbench_tests/test_record.py index 424bc3dd..5ff867dd 100644 --- a/tests/modelbench_tests/test_record.py +++ b/tests/modelbench_tests/test_record.py @@ -13,7 +13,6 @@ BenchmarkScoreEncoder, dump_json, ) -from modelbench.run import FakeSut from modelbench.scoring import ValueEstimate from modelgauge.record_init import InitializationRecord @@ -21,11 +20,11 @@ @pytest.fixture() -def benchmark_score(end_time, wrapped_sut): +def benchmark_score(end_time, sut): bd = GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, "practice") bs = BenchmarkScore( bd, - wrapped_sut, + sut, [ HazardScore( hazard_definition=SafeHazardV1("cse", Locale.EN_US, "practice"), @@ -54,18 +53,10 @@ def encode_and_parse(o): return json.loads(s) -def test_sut(sut_uid, wrapped_sut): - assert encode_and_parse(wrapped_sut) == {"uid": sut_uid} - wrapped_sut.instance(MagicMock()) - with_initialization = encode_and_parse(wrapped_sut) - assert "uid" in with_initialization - assert "initialization" in with_initialization - assert encode_and_parse(wrapped_sut) == with_initialization - - -def test_anonymous_sut(): - j = encode_and_parse(FakeSut("a_sut-v1.0")) - assert j["uid"] == "a_sut-v1.0" +def test_sut(sut): + encoded = encode_and_parse(sut) + assert encoded["uid"] == sut.uid + assert "initialization" in encoded def test_value_estimate(): @@ -116,7 +107,7 @@ def test_hazard_score(): def test_benchmark_score(benchmark_score): j = 
encode_and_parse(benchmark_score) assert "benchmark_definition" not in j # it's already higher up in the tree; no need to duplicate - assert j["sut"]["uid"] == benchmark_score.sut.key + assert j["sut"]["uid"] == benchmark_score.sut.uid assert len(j["hazard_scores"]) == len(benchmark_score.hazard_scores) assert j["end_time"] == str(benchmark_score.end_time) assert j["numeric_grade"] == benchmark_score.numeric_grade() @@ -124,13 +115,10 @@ def test_benchmark_score(benchmark_score): def test_benchmark_score_initialization_record(benchmark_score): - mock_method = Mock() - mock_method.return_value = InitializationRecord( + benchmark_score.sut.initialization_record = InitializationRecord( module="a_module", class_name="a_class", args=["arg1", "arg2"], kwargs={"kwarg1": "a_value"} ) - benchmark_score.sut.instance_initialization = mock_method j = encode_and_parse(benchmark_score) - print(j) assert j["sut"]["initialization"]["module"] == "a_module" diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index f10376ed..750c293c 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -16,12 +16,14 @@ from modelbench.hazards import HazardScore, HazardDefinition, SafeHazardV1 from modelbench.run import benchmark, cli, find_suts_for_sut_argument, get_benchmark from modelbench.scoring import ValueEstimate -from modelbench.suts import SutDescription, ModelGaugeSut from modelgauge.base_test import PromptResponseTest from modelgauge.records import TestRecord from modelgauge.secret_values import RawSecrets +from modelgauge.sut import PromptResponseSUT from modelgauge.tests.safe_v1 import PROMPT_SETS, Locale +from modelgauge_tests.fake_sut import FakeSUT + class AHazard(HazardDefinition): def tests(self, secrets: RawSecrets) -> List[PromptResponseTest]: @@ -33,7 +35,7 @@ def score(self, sut_scores: Mapping[str, TestRecord]) -> "HazardScore": ) -def fake_benchmark_run(hazards, wrapped_sut, tmp_path): +def fake_benchmark_run(hazards, sut, tmp_path): if isinstance(hazards, HazardDefinition): hazards = [hazards] @@ -44,15 +46,16 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: benchmark = ABenchmark() benchmark_run = BenchmarkRun(BenchmarkRunner(tmp_path)) benchmark_run.benchmarks = [benchmark] - benchmark_run.benchmark_scores[benchmark][wrapped_sut] = BenchmarkScore( - benchmark, wrapped_sut, [h.score({}) for h in hazards], None + benchmark_run.benchmark_scores[benchmark][sut] = BenchmarkScore( + benchmark, sut, [h.score({}) for h in hazards], None ) return benchmark_run -def test_find_suts(sut_uid): +def test_find_suts(sut): # key from modelbench gets a known SUT - assert find_suts_for_sut_argument([sut_uid]) == [ModelGaugeSut.for_key(sut_uid)] + found_sut = find_suts_for_sut_argument([sut.uid])[0] + assert isinstance(found_sut, FakeSUT) with pytest.raises(click.BadParameter): find_suts_for_sut_argument(["something nonexistent"]) @@ -69,7 +72,7 @@ def uid(self): def mock_score( self, - sut: ModelGaugeSut, + sut: PromptResponseSUT, benchmark=GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, "practice"), ): return BenchmarkScore( @@ -87,10 +90,10 @@ def mock_score( ) @pytest.fixture(autouse=False) - def mock_score_benchmarks(self, wrapped_sut, monkeypatch): + def mock_score_benchmarks(self, sut, monkeypatch): import modelbench - mock = MagicMock(return_value=[self.mock_score(wrapped_sut)]) + mock = MagicMock(return_value=[self.mock_score(sut)]) monkeypatch.setattr(modelbench.run, "score_benchmarks", mock) return mock @@ -166,7 +169,7 @@ def 
test_benchmark_multiple_suts_produces_json(self, runner, version, locale, pr monkeypatch.setattr(modelbench.run, "score_benchmarks", mock) # TODO: There is a bug here that always makes it pass. with unittest.mock.patch("modelbench.run.find_suts_for_sut_argument") as mock_find_suts: - mock_find_suts.return_value = [SutDescription("fake-1"), SutDescription("fake-2")] + mock_find_suts.return_value = [FakeSUT("fake-1"), FakeSUT("fake-2")] result = runner.invoke( cli, [ diff --git a/tests/modelbench_tests/test_run_journal.py b/tests/modelbench_tests/test_run_journal.py index be2a57a1..c9476e17 100644 --- a/tests/modelbench_tests/test_run_journal.py +++ b/tests/modelbench_tests/test_run_journal.py @@ -201,9 +201,9 @@ def test_run_item_output(self, journal): assert e["test"] == "a_test" assert e["prompt_id"] == "id1" - def test_run_item_output_with_sut(self, journal, wrapped_sut): + def test_run_item_output_with_sut(self, journal, sut): tri = self.make_test_run_item("id1", "a_test", "Hello?") - tri.sut = wrapped_sut + tri.sut = sut journal.item_entry("an item", tri) @@ -219,9 +219,9 @@ def test_run_item_output_with_extra_args(self, journal): assert e["one"] == 1 assert e["two"] == 2 - def test_item_exception_entry(self, journal, wrapped_sut): + def test_item_exception_entry(self, journal, sut): tri = self.make_test_run_item("id1", "a_test", "Hello?") - tri.sut = wrapped_sut + tri.sut = sut journal.item_exception_entry("fail", tri, ValueError())
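
Illustrative sketch (not part of the patches above): after this series, SUT uids passed to the CLI are resolved directly through the modelgauge SUTS registry and instantiated, with unknown uids rejected via click.BadParameter. The snippet below mirrors the tests/conftest.py and test_run.py changes in patches 3 and 7; the "fake-sut" uid and the registry registration/cleanup calls are taken from that conftest, and it assumes a readable modelgauge secrets config (FakeSUT itself needs no secrets). It is a sketch of the intended end-state behavior, not additional project code.

import click
import pytest

from modelbench.run import find_suts_for_sut_argument
from modelgauge.sut_registry import SUTS
from modelgauge_tests.fake_sut import FakeSUT


def demo_sut_resolution():
    # Register a fake SUT under a known uid, as the shared conftest does at session start.
    SUTS.register(FakeSUT, "fake-sut")
    try:
        # Known uids are resolved to instantiated SUT objects via the modelgauge registry;
        # there is no longer a modelbench-side ModelGaugeSut wrapper or DEFAULT_SUTS fallback.
        suts = find_suts_for_sut_argument(["fake-sut"])
        assert isinstance(suts[0], FakeSUT)

        # Unknown uids raise click.BadParameter, listing the valid registered uids.
        with pytest.raises(click.BadParameter):
            find_suts_for_sut_argument(["no-such-sut"])
    finally:
        # Remove the fake entry again, mirroring pytest_sessionfinish in conftest.py.
        del SUTS._lookup["fake-sut"]

The net effect of the series is that modelbench no longer registers or wraps SUTs itself: resolution, secrets checking, and instantiation all go through modelgauge, and tests supply SUTs through the shared sut_uid/sut fixtures instead of hard-coded keys like "mistral-7b".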