From 2626ece305fb5f7dcd1cfecff261633fd3a1b9e0 Mon Sep 17 00:00:00 2001 From: Igor Gitman Date: Sun, 21 Jan 2024 14:02:23 -0800 Subject: [PATCH] Add include_text parameter to SFT dataloaders (#8198) * Add include_text parameter to SFT dataloaders Signed-off-by: Igor Gitman * Rename include_text -> output_original_text Signed-off-by: Igor Gitman --------- Signed-off-by: Igor Gitman --- .../megatron/gpt_sft_chat_dataset.py | 2 ++ .../megatron/gpt_sft_dataset.py | 20 ++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py index 96cc57a300b8..3d5d7effc9de 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py @@ -344,6 +344,8 @@ def _process_example(self, example): # store metadata in dataset, in case user may have keys required in the prediction json files metadata = {k: v for k, v in example.items() if k not in ['conversations']} result['metadata'] = metadata + if self.output_original_text: + result['metadata']['conversations'] = example['conversations'] return result diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index 07cddce23a0c..93c4a33419bb 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -57,6 +57,7 @@ def __init__( hf_dataset: bool = False, truncation_method: str = 'right', special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} + output_original_text: bool = False, ): """ file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} @@ -79,6 +80,7 @@ def __init__( hf_dataset: Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. truncation_method: Truncation from which position. Options: ['left', 'right'] special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + output_original_text (bool): if true, will keep the original text in the output alongside the tokenized ids. """ self.tokenizer = tokenizer self.file_path = file_path @@ -99,6 +101,7 @@ def __init__( self.virtual_tokens = virtual_tokens self.tokens_to_generate = tokens_to_generate self.truncation_method = truncation_method + self.output_original_text = output_original_text if special_tokens is None: self.special_tokens = { "system_turn_start": "", @@ -209,17 +212,17 @@ def _separate_template(self, prompt_template_values: List[str]): Returns: template_strings (List[str]): separated prompt_template with contexts/label placeholder filled with corresponding strings template_strings_keys (List[str]): strings point to placeholder keys or