Add include_text parameter to SFT dataloaders (NVIDIA#8198)
* Add include_text parameter to SFT dataloaders

Signed-off-by: Igor Gitman <[email protected]>

* Rename include_text -> output_original_text

Signed-off-by: Igor Gitman <[email protected]>

---------

Signed-off-by: Igor Gitman <[email protected]>
Kipok authored and odelalleau committed Jan 31, 2024
1 parent 2c4c7d8 commit 2626ece
Showing 2 changed files with 15 additions and 7 deletions.
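For context, here is a minimal usage sketch. The class name GPTSFTDataset and its import path are assumptions based on NeMo's SFT dataloaders (the scraped diff names neither file); only the output_original_text flag itself comes from this change.

# Sketch only: the class and import path are assumed, not shown in this diff.
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset

# `tokenizer` is assumed to be a NeMo-compatible tokenizer built elsewhere.
dataset = GPTSFTDataset(
    file_path='train.jsonl',    # JSONL data as described in the docstring below
    tokenizer=tokenizer,
    output_original_text=True,  # new flag: keep raw text in each example's metadata
)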
@@ -344,6 +344,8 @@ def _process_example(self, example):
# store metadata in dataset, in case user may have keys required in the prediction json files
metadata = {k: v for k, v in example.items() if k not in ['conversations']}
result['metadata'] = metadata
if self.output_original_text:
result['metadata']['conversations'] = example['conversations']

return result
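A sketch of what this branch preserves, using a hypothetical chat record (the turn layout is an assumption; only the 'conversations' key appears in the diff):

# Hypothetical example record; the turn structure is assumed.
example = {
    'system': 'You are a helpful assistant.',
    'conversations': [
        {'from': 'User', 'value': 'What did artificial viscosity do?'},
        {'from': 'Assistant', 'value': 'It smoothed the shock transition.'},
    ],
}
metadata = {k: v for k, v in example.items() if k not in ['conversations']}
metadata['conversations'] = example['conversations']  # the new branch above
# metadata now carries the untouched turns for prediction files,
# alongside the tokenized ids stored elsewhere in `result`.

The second changed file below threads the same flag through a template-based SFT dataset.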

@@ -57,6 +57,7 @@ def __init__(
hf_dataset: bool = False,
truncation_method: str = 'right',
special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictionary of {token_type: token}
output_original_text: bool = False,
):
"""
file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines, each formatted as follows: {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
@@ -79,6 +80,7 @@ def __init__(
hf_dataset: Whether to load the JSON file with the HuggingFace dataset; otherwise, the JSONL file is loaded with the JSONLMemMapDataset.
truncation_method: Truncation from which position. Options: ['left', 'right']
special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
output_original_text (bool): If True, keeps the original text in the output alongside the tokenized ids.
"""
self.tokenizer = tokenizer
self.file_path = file_path
@@ -99,6 +101,7 @@ def __init__(
self.virtual_tokens = virtual_tokens
self.tokens_to_generate = tokens_to_generate
self.truncation_method = truncation_method
self.output_original_text = output_original_text
if special_tokens is None:
self.special_tokens = {
"system_turn_start": "<extra_id_0>",
@@ -209,17 +212,17 @@ def _separate_template(self, prompt_template_values: List[str]):
Returns:
template_strings (List[str]): the separated prompt_template pieces, with context/label placeholders replaced by the corresponding strings
template_strings_keys (List[str]): keys indicating whether each piece is a placeholder key or <template>
Examples:
prompt_template = 'Context: {context} Question: {question} Answer: {label}'
prompt_template_values = ['xxx', 'yyy', 'zzz']
# tokenizer.space_sensitive = True
template_strings = ['Context:', ' xxx', ' Question:', ' yyy', ' Answer:', ' zzz']
# tokenizer.space_sensitive = False
template_strings = ['Context:', ' xxx', 'Question:', 'yyy', 'Answer:', 'zzz']
template_strings_keys = ['<template>', 'context', '<template>', 'question', '<template>', 'label']
"""
placeholders = [f'{{{k}}}' for k in self.prompt_template_keys]
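A runnable sketch of the splitting idea described in the docstring, under stated assumptions: this hypothetical helper ignores the tokenizer space-sensitivity handling shown in the example above and simply splits on the placeholders.

import re

def separate_template(prompt_template, keys, values):
    # Hypothetical simplification, not NeMo's exact code.
    placeholders = [f'{{{k}}}' for k in keys]
    pattern = '(' + '|'.join(re.escape(p) for p in placeholders) + ')'
    template_strings, template_strings_keys = [], []
    for piece in re.split(pattern, prompt_template):
        if not piece:
            continue
        if piece in placeholders:
            key = piece[1:-1]  # strip the braces
            template_strings.append(values[keys.index(key)])
            template_strings_keys.append(key)
        else:
            template_strings.append(piece)
            template_strings_keys.append('<template>')
    return template_strings, template_strings_keys

# separate_template('Context: {context} Question: {question} Answer: {label}',
#                   ['context', 'question', 'label'], ['xxx', 'yyy', 'zzz'])
# -> (['Context: ', 'xxx', ' Question: ', 'yyy', ' Answer: ', 'zzz'],
#     ['<template>', 'context', '<template>', 'question', '<template>', 'label'])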
@@ -257,7 +260,7 @@ def _separate_template(self, prompt_template_values: List[str]):
def _multiple_truncation(self, template_ids: List[List[int]], template_ids_keys: List[str]):
"""
Calculate total tokens and truncate multiple contexts in truncation_fields.
Args:
template_ids (List[List[int]]): the list of separate prompt_template ids.
template_ids_keys (List[str]): the list of placeholder keys or <template> (used to check key in truncation_fields).
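The body of _multiple_truncation is collapsed in this view. A rough sketch of the stated idea, with an assumption made explicit: the real distribution of excess tokens across truncation_fields is not visible here, so simple left-to-right truncation is used instead.

def multiple_truncation(template_ids, template_ids_keys, truncation_fields,
                        max_len, truncate_left=True):
    # Hypothetical sketch: drop the token excess from truncatable fields,
    # scanning fields left to right (the real policy may differ).
    excess = sum(len(ids) for ids in template_ids) - max_len
    for ids, key in zip(template_ids, template_ids_keys):
        if excess <= 0:
            break
        if key in truncation_fields:
            cut = min(excess, len(ids))
            ids[:] = ids[cut:] if truncate_left else ids[:len(ids) - cut]
            excess -= cut
    return template_ids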
@@ -356,6 +359,9 @@ def _process_example(self, example):

# store metadata in dataset, in case user may have keys required in the prediction json files
metadata = {k: v for k, v in example.items() if k not in self.prompt_template_keys}
if self.output_original_text:
for orig_text, text_key in zip(template_strings, template_strings_keys):
metadata[text_key] = orig_text

processed_example = {
'input_ids': input_ids,
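To make the new loop concrete, here is its effect on the docstring example from _separate_template (values taken from that example, not from a real run):

template_strings = ['Context:', ' xxx', ' Question:', ' yyy', ' Answer:', ' zzz']
template_strings_keys = ['<template>', 'context', '<template>', 'question', '<template>', 'label']

metadata = {}
for orig_text, text_key in zip(template_strings, template_strings_keys):
    metadata[text_key] = orig_text
# metadata == {'<template>': ' Answer:', 'context': ' xxx',
#              'question': ' yyy', 'label': ' zzz'}
# Successive '<template>' pieces overwrite each other, so only the last
# literal segment survives under the '<template>' key.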
