Make sure to properly close the journal even when things fail (#817)
* Make sure to properly close the journal even when things fail.

* missed a test
wpietri authored Jan 24, 2025
1 parent 9bacf71 commit 6046471
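
For readers skimming the diff below, the fix works by turning the run object into a context manager: the journal is closed from __exit__ whether the run finishes normally or raises. The following is a minimal, self-contained sketch of that guarantee; DemoJournal and DemoRun are illustrative stand-ins, not types from this repository.

class DemoJournal:
    """Stand-in for the run journal: records entries and tracks whether it was closed."""

    def __init__(self):
        self.entries = []
        self.closed = False

    def raw_entry(self, message, **kwargs):
        self.entries.append((message, kwargs))

    def close(self):
        self.closed = True


class DemoRun:
    """Mirrors the __enter__/__exit__ pair this commit adds to the real run class."""

    def __init__(self):
        self.journal = DemoJournal()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type:
            self.journal.raw_entry("exception stopping run", exc_type=str(exc_type), exc_val=exc_val)
        self.journal.raw_entry("closing journal")
        self.journal.close()
        # Returning None (falsy) means the exception, if any, keeps propagating.


demo = DemoRun()
try:
    with demo:
        raise RuntimeError("pipeline blew up")  # simulate a mid-run failure
except RuntimeError:
    pass

assert demo.journal.closed
assert demo.journal.entries[0][0] == "exception stopping run"
assert demo.journal.entries[-1][0] == "closing journal"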
Showing 2 changed files with 60 additions and 53 deletions.
109 changes: 58 additions & 51 deletions src/modelbench/benchmark_runner.py
@@ -11,9 +11,6 @@
 from multiprocessing.pool import ThreadPool
 from typing import Any, Iterable, Optional, Sequence
 
-from pydantic import BaseModel
-from tqdm import tqdm
-
 from modelbench.benchmark_runner_items import ModelgaugeTestWrapper, TestRunItem, Timer
 from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore
 from modelbench.cache import DiskCache, MBCache
@@ -27,6 +24,8 @@
 from modelgauge.records import TestRecord
 from modelgauge.single_turn_prompt_response import PromptWithContext, TestItem
 from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
+from pydantic import BaseModel
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)

@@ -189,6 +188,15 @@ def cache_info(self):
             result.append(f" {key}: finished with {len(self.caches[key])}")
         return "\n".join(result)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type:
+            self.journal.raw_entry("exception stopping run", exc_type=str(exc_type), exc_val=exc_val)
+        self.journal.raw_entry("closing journal")
+        self.journal.close()
+
 
 class TestRun(TestRunBase):
     tests: list[ModelgaugeTestWrapper]
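
A note on the __exit__ added above: it returns None, and Python only suppresses an in-flight exception when __exit__ returns a truthy value, so the journal gets its final entries and is closed while the original failure still propagates to the caller. A small illustrative example of that rule follows; Closing is a made-up class, not part of modelbench.

class Closing:
    def __init__(self, suppress):
        self.suppress = suppress
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.closed = True
        return self.suppress  # falsy, like the journal's __exit__, means "re-raise"


try:
    with Closing(suppress=False) as c:
        raise ValueError("boom")
except ValueError:
    pass
assert c.closed  # cleanup ran, and the error still surfaced

with Closing(suppress=True) as c2:
    raise ValueError("swallowed")  # suppressed because __exit__ returned True
assert c2.closed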
@@ -570,60 +578,59 @@ def _check_ready_to_run(self):
     def run(self) -> BenchmarkRun:
         self._check_ready_to_run()
 
-        benchmark_run = BenchmarkRun(self)
-        self._check_external_services(benchmark_run)
-        benchmark_run.journal.raw_entry(
-            "starting run",
-            run_id=benchmark_run.run_id,
-            benchmarks=[b.uid for b in benchmark_run.benchmarks],
-            tests=[t.uid for t in benchmark_run.tests],
-            suts=[s.uid for s in benchmark_run.suts],
-            max_items=benchmark_run.max_items,
-            thread_count=self.thread_count,
-        )
-        for test in benchmark_run.tests:
+        with BenchmarkRun(self) as benchmark_run:
+            self._check_external_services(benchmark_run)
             benchmark_run.journal.raw_entry(
-                "test info",
-                test=test.uid,
-                initialization=test.initialization_record,
-                sut_options=test.sut_options(),
-                dependencies=test.dependencies(),
+                "starting run",
+                run_id=benchmark_run.run_id,
+                benchmarks=[b.uid for b in benchmark_run.benchmarks],
+                tests=[t.uid for t in benchmark_run.tests],
+                suts=[s.uid for s in benchmark_run.suts],
+                max_items=benchmark_run.max_items,
+                thread_count=self.thread_count,
             )
-        pipeline = self._build_pipeline(benchmark_run)
-        benchmark_run.run_tracker.start(self._expected_item_count(benchmark_run, pipeline))
-        benchmark_run.journal.raw_entry("running pipeline")
-        with Timer() as timer:
-            pipeline.run()
-
-        total_items_finished = 0
-        finished_item_counts = defaultdict(dict)
-        for k1, d1 in benchmark_run.finished_items.items():
-            for k2, l1 in d1.items():
-                total_items_finished += len(d1)
-                finished_item_counts[k1][k2] = len(d1)
-
-        benchmark_run.journal.raw_entry(
-            "finished pipeline",
-            time=timer.elapsed,
-            total_finished=total_items_finished,
-            finished_counts=finished_item_counts,
-        )
+            for test in benchmark_run.tests:
+                benchmark_run.journal.raw_entry(
+                    "test info",
+                    test=test.uid,
+                    initialization=test.initialization_record,
+                    sut_options=test.sut_options(),
+                    dependencies=test.dependencies(),
+                )
+            pipeline = self._build_pipeline(benchmark_run)
+            benchmark_run.run_tracker.start(self._expected_item_count(benchmark_run, pipeline))
+            benchmark_run.journal.raw_entry("running pipeline")
+            with Timer() as timer:
+                pipeline.run()
+
+            total_items_finished = 0
+            finished_item_counts = defaultdict(dict)
+            for k1, d1 in benchmark_run.finished_items.items():
+                for k2, l1 in d1.items():
+                    total_items_finished += len(d1)
+                    finished_item_counts[k1][k2] = len(d1)
 
-        self._calculate_test_results(benchmark_run)
-        self._calculate_benchmark_scores(benchmark_run)
-        benchmark_run.run_tracker.done()
-        benchmark_run.journal.raw_entry("finished run", run_id=benchmark_run.run_id)
-        for key, cache in benchmark_run.caches.items():
-            cache = benchmark_run.caches[key]
             benchmark_run.journal.raw_entry(
-                "cache info",
-                type=key,
-                cache=str(cache),
-                start_count=benchmark_run.cache_starting_size[key],
-                end_count=len(cache),
+                "finished pipeline",
+                time=timer.elapsed,
+                total_finished=total_items_finished,
+                finished_counts=finished_item_counts,
            )
 
-        benchmark_run.journal.close()
+            self._calculate_test_results(benchmark_run)
+            self._calculate_benchmark_scores(benchmark_run)
+            benchmark_run.run_tracker.done()
+            benchmark_run.journal.raw_entry("finished run", run_id=benchmark_run.run_id)
+            for key, cache in benchmark_run.caches.items():
+                cache = benchmark_run.caches[key]
+                benchmark_run.journal.raw_entry(
+                    "cache info",
+                    type=key,
+                    cache=str(cache),
+                    start_count=benchmark_run.cache_starting_size[key],
+                    end_count=len(cache),
+                )
+
         return benchmark_run
 
     def _calculate_benchmark_scores(self, benchmark_run):
4 changes: 2 additions & 2 deletions tests/modelbench_tests/test_benchmark_runner.py
@@ -3,7 +3,6 @@
 from unittest.mock import MagicMock
 
 import pytest
-
 from modelbench.benchmark_runner import *
 from modelbench.cache import InMemoryCache
 from modelbench.hazards import HazardDefinition, HazardScore
@@ -19,9 +18,9 @@
 from modelgauge.sut import SUTCompletion, SUTResponse
 from modelgauge.sut_registry import SUTS
 from modelgauge.suts.demo_01_yes_no_sut import DemoYesNoResponse
-from modelgauge_tests.fake_annotator import FakeAnnotator
 
 from modelbench_tests.test_run_journal import FakeJournal, reader_for
+from modelgauge_tests.fake_annotator import FakeAnnotator
 from modelgauge_tests.fake_sut import FakeSUT
 
 # fix pytest autodiscovery issue; see https://github.com/pytest-dev/pytest/issues/12749
@@ -641,6 +640,7 @@ def test_basic_benchmark_run(self, tmp_path, a_sut, fake_secrets, benchmark):
             "finished run",
             "cache info",
             "cache info",
+            "closing journal",
         ]
         # a BenchmarkScore keeps track of the various numbers used to arrive at a score
         # so we can check its work. We make sure that log is in the journal.
