Refactor huggingface config support (#742)
* do not override config deprefix_prompt

Signed-off-by: Jeffrey Martin <[email protected]>

* improve code reuse

* consolidate `__init__` where possible
* shift generator or model object creation to `_load_client()`

Signed-off-by: Jeffrey Martin <[email protected]>

* crude implementation of a limit on parallel generator calls

Signed-off-by: Jeffrey Martin <[email protected]>

* add torch `mps` support & enable passed pipeline params

* detect cuda vs mps vs cpu in a common way
* guard import of `OptimumPipeline`

Signed-off-by: Jeffrey Martin <[email protected]>

* enable hf model or pipeline config in `hf_args`

* support all generic `pipeline` args at all times
* adds `do_sample` when `model` is a parameter to the `Callable`
* adds `low_cpu_mem_usage` and all `pipeline` for `Callables` without `model`
* consolidates optimal device selection & sets it when not provided by config

Signed-off-by: Jeffrey Martin <[email protected]>

* amend yaml config example

* support merged dictionary in `Configurable`

Signed-off-by: Jeffrey Martin <[email protected]>

* free tokenizer in _clear_client

Signed-off-by: Jeffrey Martin <[email protected]>

* explicit device support

* raise error when passed negative device integer
* rename parameter tracking var
* remove unused import
* add tests for `_select_hf_device()`

Signed-off-by: Jeffrey Martin <[email protected]>
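
The device rules described in the bullets above (prefer cuda, then mps, then cpu; reject negative device integers) can be sketched roughly as follows. This is a minimal illustration, not the actual `_select_hf_device()` from the commit: the function name `select_device`, its parameters, and the mapping of an integer to a `cuda:N` string are assumptions made for the example. In real code the two availability flags would presumably come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
def select_device(requested=None, cuda_available=False, mps_available=False):
    """Hypothetical helper: pick a torch-style device string.

    `requested` is an explicit device from config (int, str, or None).
    The availability flags are passed in so the sketch runs without torch.
    """
    if isinstance(requested, int):
        if requested < 0:
            # Mirrors the commit note: negative device integers are an error.
            raise ValueError(f"device must be a non-negative integer, got {requested}")
        # Assumption: an integer means a CUDA device ordinal.
        return f"cuda:{requested}"
    if requested is not None:
        # An explicit string such as "mps" or "cuda:1" is used as given.
        return str(requested)
    # Nothing provided by config: fall back to the best available backend.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

Passing the availability flags as arguments keeps the selection logic testable on machines without a GPU, which is one plausible reason to centralize it in a single helper.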

---------

Signed-off-by: Jeffrey Martin <[email protected]>
jmartin-tech authored Jun 27, 2024
1 parent d49fab0 commit 21dd343
Showing 7 changed files with 265 additions and 138 deletions.
5 changes: 5 additions & 0 deletions garak/configurable.py
@@ -88,6 +88,8 @@ def _apply_config(self, config):
                )
            ):
                continue
            if isinstance(v, dict):  # if value is an existing dictionary merge
                v = getattr(self, k) | v
            setattr(self, k, v)  # This will set attribute to the full dictionary value

def _apply_missing_instance_defaults(self):
@@ -96,6 +98,9 @@ def _apply_missing_instance_defaults(self):
        for k, v in self.DEFAULT_PARAMS.items():
            if not hasattr(self, k):
                setattr(self, k, v)
            elif isinstance(v, dict):
                v = v | getattr(self, k)
                setattr(self, k, v)

    def _validate_env_var(self):
        if hasattr(self, "key_env_var"):
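
The two merges in this file use the dict union operator `|` (Python 3.9+), where the right-hand operand wins on key clashes. The operand order therefore sets precedence: incoming config overrides existing values, while defaults only fill in missing keys. A minimal sketch of that behavior follows; the `Demo` class, its method names, and the `hf_args` keys are hypothetical stand-ins, and the extra `isinstance` guard on the existing attribute is an addition for the example, not part of the diff.

```python
class Demo:
    """Hypothetical stand-in for a Configurable subclass."""

    DEFAULT_PARAMS = {"hf_args": {"torch_dtype": "float16", "device": "cpu"}}

    def apply_config(self, config):
        for k, v in config.items():
            # Guard added for this sketch: only merge when a dict is already set.
            if isinstance(v, dict) and isinstance(getattr(self, k, None), dict):
                v = getattr(self, k) | v  # incoming config wins on clashes
            setattr(self, k, v)

    def apply_missing_defaults(self):
        for k, v in self.DEFAULT_PARAMS.items():
            if not hasattr(self, k):
                setattr(self, k, v)
            elif isinstance(v, dict):
                setattr(self, k, v | getattr(self, k))  # existing values win


demo = Demo()
demo.apply_config({"hf_args": {"device": "cuda"}})
demo.apply_missing_defaults()  # fills in torch_dtype, keeps device="cuda"
demo.apply_config({"hf_args": {"torch_dtype": "bfloat16"}})  # overrides one key
print(demo.hf_args)
```

Note that `|` performs a shallow merge: nested dictionaries inside `hf_args` would be replaced wholesale, not merged key by key.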
Expand Down
1 change: 1 addition & 0 deletions garak/generators/base.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@ class Generator(Configurable):

    active = True
    generator_family_name = None
    parallel_capable = True

    # support mainstream any-to-any large models
    # legal element for str list `modality['in']`: 'text', 'image', 'audio', 'video', '3d'
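
A class-level flag like `parallel_capable` lets calling code decide whether a generator may be invoked concurrently. The sketch below shows one way a caller might honor it; the `run_all` function, the `LocalModelGenerator` subclass, and the use of a thread pool are assumptions for illustration and do not reflect how garak itself dispatches parallel calls.

```python
from concurrent.futures import ThreadPoolExecutor


class Generator:
    """Hypothetical minimal base class carrying the flag from the diff."""

    parallel_capable = True

    def generate(self, prompt):
        return prompt.upper()  # placeholder for a real model call


class LocalModelGenerator(Generator):
    # Assumption: a generator holding a single in-process model opts out.
    parallel_capable = False


def run_all(generator, prompts, workers=4):
    """Run prompts concurrently only if the generator allows it."""
    if generator.parallel_capable and workers > 1:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            return list(pool.map(generator.generate, prompts))
    # Fall back to sequential calls for generators that opt out.
    return [generator.generate(p) for p in prompts]
```

Defaulting the flag to `True` on the base class means existing generators keep their current behavior, and only subclasses that cannot be shared safely need to override it.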