diff --git a/garak/generators/ggml.py b/garak/generators/ggml.py index bc31c5fe8..3f04cb278 100644 --- a/garak/generators/ggml.py +++ b/garak/generators/ggml.py @@ -41,12 +41,15 @@ class GgmlGenerator(Generator): "exception_on_failure": True, "first_call": True, "key_env_var": ENV_VAR, + "extra_ggml_flags": ["-no-cnv"], + "extra_ggml_params": dict(), } generator_family_name = "ggml" - def command_params(self): - return { + def _command_args_list(self): + command_list = [] + params = { "-m": self.name, "-n": self.max_tokens, "--repeat-penalty": self.repeat_penalty, @@ -56,7 +59,15 @@ def command_params(self): "--top-p": self.top_p, "--temp": self.temperature, "-s": self.seed, - } + } | self.extra_ggml_params + # test all params for None type + for key, value in params.items(): + if value is not None: + command_list.append(key) + command_list.append(value) + if isinstance(self.extra_ggml_flags, list): + command_list.extend(self.extra_ggml_flags) + return command_list def __init__(self, name="", config_root=_config): self.name = name @@ -107,10 +118,7 @@ def _call_model( prompt, ] # test all params for None type - for key, value in self.command_params().items(): - if value is not None: - command.append(key) - command.append(value) + command.extend(self._command_args_list()) command = [str(param) for param in command] if _config.system.verbose > 1: print("GGML invoked with", command) diff --git a/tests/generators/test_ggml.py b/tests/generators/test_ggml.py index 3ed2a212a..3a12f6f5b 100644 --- a/tests/generators/test_ggml.py +++ b/tests/generators/test_ggml.py @@ -45,3 +45,36 @@ def test_init_good_model(): g = garak.generators.ggml.GgmlGenerator(file.name) os.remove(file.name) assert type(g) is garak.generators.ggml.GgmlGenerator + + +def test_command_args_list(): + """ensure command list overrides apply and `extra_ggml_params` are in correct relative order""" + with tempfile.NamedTemporaryFile(suffix="_test_model.gguf", delete=False) as file: + file.write(garak.generators.ggml.GGUF_MAGIC) + file.close() + + gen_config = { + "extra_ggml_flags": [ + "test_value", + "another_value", + ], + "extra_ggml_params": { + "custom_param": "custom_value", + }, + } + + config_root = {"generators": {"ggml": {"GgmlGenerator": gen_config}}} + + g = garak.generators.ggml.GgmlGenerator(file.name, config_root=config_root) + arg_list = g._command_args_list() + for arg in gen_config["extra_ggml_flags"]: + assert arg in arg_list + for arg, value in gen_config["extra_ggml_params"].items(): + assert arg in arg_list + assert value in arg_list + arg_index = arg_list.index(arg) + value_index = arg_list.index(value) + assert arg_index + 1 == value_index + + os.remove(file.name) + assert type(g) is garak.generators.ggml.GgmlGenerator