Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions garak/generators/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ class GgmlGenerator(Generator):
"exception_on_failure": True,
"first_call": True,
"key_env_var": ENV_VAR,
"extra_ggml_flags": ["-no-cnv"],
"extra_ggml_params": dict(),
}

generator_family_name = "ggml"

def command_params(self):
return {
def _command_args_list(self):
command_list = []
params = {
"-m": self.name,
"-n": self.max_tokens,
"--repeat-penalty": self.repeat_penalty,
Expand All @@ -56,7 +59,15 @@ def command_params(self):
"--top-p": self.top_p,
"--temp": self.temperature,
"-s": self.seed,
}
} | self.extra_ggml_params
# test all params for None type
for key, value in params.items():
if value is not None:
command_list.append(key)
command_list.append(value)
if isinstance(self.extra_ggml_flags, list):
command_list.extend(self.extra_ggml_flags)
return command_list

def __init__(self, name="", config_root=_config):
self.name = name
Expand Down Expand Up @@ -107,10 +118,7 @@ def _call_model(
prompt,
]
# test all params for None type
for key, value in self.command_params().items():
if value is not None:
command.append(key)
command.append(value)
command.extend(self._command_args_list())
command = [str(param) for param in command]
if _config.system.verbose > 1:
print("GGML invoked with", command)
Expand Down
33 changes: 33 additions & 0 deletions tests/generators/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,36 @@ def test_init_good_model():
g = garak.generators.ggml.GgmlGenerator(file.name)
os.remove(file.name)
assert type(g) is garak.generators.ggml.GgmlGenerator


def test_command_args_list():
"""ensure command list overrides apply and `extra_ggml_params` are in correct relative order"""
with tempfile.NamedTemporaryFile(suffix="_test_model.gguf", delete=False) as file:
file.write(garak.generators.ggml.GGUF_MAGIC)
file.close()

gen_config = {
"extra_ggml_flags": [
"test_value",
"another_value",
],
"extra_ggml_params": {
"custom_param": "custom_value",
},
}

config_root = {"generators": {"ggml": {"GgmlGenerator": gen_config}}}

g = garak.generators.ggml.GgmlGenerator(file.name, config_root=config_root)
arg_list = g._command_args_list()
for arg in gen_config["extra_ggml_flags"]:
assert arg in arg_list
for arg, value in gen_config["extra_ggml_params"].items():
assert arg in arg_list
assert value in arg_list
arg_index = arg_list.index(arg)
value_index = arg_list.index(value)
assert arg_index + 1 == value_index

os.remove(file.name)
assert type(g) is garak.generators.ggml.GgmlGenerator