From bc239ca32fe0f42d1ee011ceaa33b035c56a73b8 Mon Sep 17 00:00:00 2001 From: Leon Derczynski Date: Fri, 19 Jan 2024 07:55:00 +0100 Subject: [PATCH] replay.Repeat now preserves attempt when restoring generator max_tokens --- garak/harnesses/base.py | 7 ++++--- garak/probes/base.py | 4 ++-- garak/probes/replay.py | 3 ++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/garak/harnesses/base.py b/garak/harnesses/base.py index 7c407e959..c359d22a5 100644 --- a/garak/harnesses/base.py +++ b/garak/harnesses/base.py @@ -20,7 +20,7 @@ import tqdm -from garak.attempt import * +import garak.attempt from garak import _config from garak import _plugins @@ -31,7 +31,7 @@ class Harness: active = True def __init__(self): - logging.info(f"harness init: %s", self) + logging.info("harness init: %s", self) def _load_buffs(self, buffs: List) -> None: """load buff instances into global config @@ -90,6 +90,7 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None: if not probe: continue attempt_results = probe.probe(model) + assert isinstance(attempt_results, list) eval_outputs, eval_results = [], defaultdict(list) first_detector = True @@ -109,7 +110,7 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None: first_detector = False for attempt in attempt_results: - attempt.status = ATTEMPT_COMPLETE + attempt.status = garak.attempt.ATTEMPT_COMPLETE _config.transient.reportfile.write(json.dumps(attempt.as_dict()) + "\n") if len(attempt_results) == 0: diff --git a/garak/probes/base.py b/garak/probes/base.py index b28d9a15c..11e868ffe 100644 --- a/garak/probes/base.py +++ b/garak/probes/base.py @@ -42,10 +42,10 @@ def __init__(self): print( f"loading {Style.BRIGHT}{Fore.LIGHTYELLOW_EX}probe: {Style.RESET_ALL}{self.probename}" ) - logging.info(f"probe init: {self}") + logging.info("probe init: {self}") if "description" not in dir(self): if self.__doc__: - self.description = self.__doc__.split("\n")[0] + self.description = self.__doc__.split("\n", maxsplit=1)[0] else: self.description = "" diff --git a/garak/probes/replay.py b/garak/probes/replay.py index c899841cd..cfe9b795f 100644 --- a/garak/probes/replay.py +++ b/garak/probes/replay.py @@ -67,9 +67,10 @@ def _generator_precall_hook(self, generator, attempt=None): self.generator_orig_tokens = self.generator.max_tokens self.generator.max_tokens = self.new_max_tokens - def _postprocess_hook(self, attempt): + def _postprocess_hook(self, attempt) -> Attempt: if self.override_maxlen and self.generator_orig_tokens is not None: self.generator.max_tokens = self.generator_orig_tokens + return attempt class RepeatExtended(Repeat):