Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve PEFT UX #8131

Merged
merged 2 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
hf_dataset: bool = False,
truncation_method: str = 'right',
special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictionary of {token_type: token}
is_test: bool = False,
):
"""
file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
Expand All @@ -79,6 +80,7 @@ def __init__(
hf_dataset: Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset.
truncation_method: Truncation from which position. Options: ['left', 'right']
special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
is_test: Whether this dataset is the test split.
"""
self.tokenizer = tokenizer
self.file_path = file_path
Expand All @@ -99,6 +101,7 @@ def __init__(
self.virtual_tokens = virtual_tokens
self.tokens_to_generate = tokens_to_generate
self.truncation_method = truncation_method
self.is_test = is_test
if special_tokens is None:
self.special_tokens = {
"system_turn_start": "<extra_id_0>",
Expand Down Expand Up @@ -317,7 +320,16 @@ def _process_example(self, example):
Truncation is carried out when needed, but it is performed only on the prompt side.
BOS, EOS, and SEP, are added if specified.
"""
prompt_template_values = [example[c].strip(' ') for c in self.prompt_template_keys]
prompt_template_values = []
for c in self.prompt_template_keys:
try:
prompt_template_values.append(example[c].strip(' '))
except KeyError as e:
if c == self.label_key and self.is_test:
# allow a missing label during testing, when the user only wants to run inference without calculating metrics
prompt_template_values.append("")
else:
raise e

template_strings, template_strings_keys = self._separate_template(prompt_template_values)
template_ids = [self.tokenizer.text_to_ids(s) for s in template_strings]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def _build_dataset(self, data_cfg, is_train=True):
special_tokens=self.cfg.data.get(
'chat_prompt_tokens', None
), # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
is_test=not is_train,
)
datasets.append(dataset)
if is_train:
Expand Down
4 changes: 1 addition & 3 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,8 @@ def _check_and_add_peft_cfg(self, peft_cfg):
layers = self.model.module.language_model.encoder.layers
else:
layers = self.model.language_model.encoder.layers
if layer_selection is None:
layer_selection = list(range(1, self.cfg.num_layers + 1))
for layer in layers:
if layer.layer_number in layer_selection:
if layer.layer_number in (layer_selection or list(range(1, self.cfg.num_layers + 1))):
for name, module in layer.named_modules():
self._check_and_add_adapter(
name, module, adapter_name, adapter_cfg, name_key_to_mcore_mixins
Expand Down
Loading