
Add include_text parameter to SFT dataloaders #8198

Merged · 2 commits · Jan 21, 2024
@@ -344,6 +344,8 @@ def _process_example(self, example):
        # store metadata in dataset, in case user may have keys required in the prediction json files
        metadata = {k: v for k, v in example.items() if k not in ['conversations']}
        result['metadata'] = metadata
+       if self.output_original_text:
+           result['metadata']['conversations'] = example['conversations']

        return result
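To make the chat-dataset change concrete, here is a minimal standalone sketch (not code from this PR; the example record is invented for illustration) of what `_process_example` does once `output_original_text` is enabled: the raw `conversations` field is copied into `metadata`, next to the tokenized ids, so prediction JSON files can echo the original text.

```python
# Illustrative example record; only the metadata handling mirrors the diff above.
example = {
    'system': 'You are a helpful assistant.',
    'conversations': [
        {'from': 'User', 'value': 'What did the math of artificial viscosity do?'},
        {'from': 'Assistant', 'value': 'Smoothed the shock transition.'},
    ],
}
output_original_text = True  # the new flag

result = {}
# store metadata in dataset, in case user may have keys required in the prediction json files
metadata = {k: v for k, v in example.items() if k not in ['conversations']}
result['metadata'] = metadata
if output_original_text:
    result['metadata']['conversations'] = example['conversations']

assert result['metadata']['system'] == 'You are a helpful assistant.'
assert result['metadata']['conversations'][0]['value'].startswith('What did')
```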

@@ -58,6 +58,7 @@ def __init__(
        truncation_method: str = 'right',
        special_tokens: Optional[Mapping[str, str]] = None,  # special tokens, a dictionary of {token_type: token}
        is_test: bool = False,
+       output_original_text: bool = False,
    ):
        """
        file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
@@ -81,6 +82,7 @@ def __init__(
        truncation_method: Truncation from which position. Options: ['left', 'right']
        special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
        is_test: Whether this dataset is the test split.
+       output_original_text (bool): If True, keeps the original text in the output alongside the tokenized ids.
        """
        self.tokenizer = tokenizer
        self.file_path = file_path
@@ -102,6 +104,7 @@ def __init__(
        self.tokens_to_generate = tokens_to_generate
        self.truncation_method = truncation_method
        self.is_test = is_test
+       self.output_original_text = output_original_text
        if special_tokens is None:
            self.special_tokens = {
                "system_turn_start": "<extra_id_0>",
@@ -212,17 +215,17 @@ def _separate_template(self, prompt_template_values: List[str]):
        Returns:
            template_strings (List[str]): separated prompt_template with context/label placeholders filled with the corresponding strings
            template_strings_keys (List[str]): strings that point to placeholder keys or <template>

        Examples:
            prompt_template = 'Context: {context} Question: {question} Answer: {label}'
            prompt_template_values = ['xxx', 'yyy', 'zzz']

            # tokenizer.space_sensitive = True
            template_strings = ['Context:', ' xxx', ' Question:', ' yyy', ' Answer:', ' zzz']

            # tokenizer.space_sensitive = False
            template_strings = ['Context:', ' xxx', 'Question:', 'yyy', 'Answer:', 'zzz']

            template_strings_keys = ['<template>', 'context', '<template>', 'question', '<template>', 'label']
        """
        placeholders = [f'{{{k}}}' for k in self.prompt_template_keys]
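The docstring example above is concrete enough to replay outside NeMo. Below is a standalone re-implementation sketch (not the library's `_separate_template`) that reproduces the space-sensitive case: split the template into literal segments and filled placeholders, keeping track of which key each piece belongs to.

```python
import re

# With a space-sensitive tokenizer, the space before each value stays
# attached to the value rather than to the preceding literal segment.
prompt_template = 'Context: {context} Question: {question} Answer: {label}'
values = {'context': 'xxx', 'question': 'yyy', 'label': 'zzz'}

template_strings, template_strings_keys = [], []
cursor = 0
for m in re.finditer(r'\{(\w+)\}', prompt_template):
    literal = prompt_template[cursor:m.start()]
    # move the literal's single trailing space onto the following value
    lead = ''
    if literal.endswith(' '):
        literal, lead = literal[:-1], ' '
    template_strings.append(literal)
    template_strings_keys.append('<template>')
    template_strings.append(lead + values[m.group(1)])
    template_strings_keys.append(m.group(1))
    cursor = m.end()

print(template_strings)
# ['Context:', ' xxx', ' Question:', ' yyy', ' Answer:', ' zzz']
print(template_strings_keys)
# ['<template>', 'context', '<template>', 'question', '<template>', 'label']
```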
@@ -260,7 +263,7 @@ def _separate_template(self, prompt_template_values: List[str]):
    def _multiple_truncation(self, template_ids: List[List[int]], template_ids_keys: List[str]):
        """
        Calculate total tokens and truncate multiple contexts in truncation_fields.

        Args:
            template_ids (List[List[int]]): the list of separate prompt_template ids.
            template_ids_keys (List[str]): the list of placeholder keys or <template> (used to check key in truncation_fields).
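For context on `truncation_fields` and the `truncation_method` option documented in the constructor ('left' or 'right'), here is a toy sketch of the two directions, independent of NeMo's actual token accounting:

```python
# Minimal illustration of left vs. right truncation of a token-id list.
def truncate(ids: list, max_len: int, truncation_method: str = 'right') -> list:
    if truncation_method == 'left':
        return ids[-max_len:]   # drop tokens from the front
    return ids[:max_len]        # drop tokens from the back

assert truncate([1, 2, 3, 4], 2, 'left') == [3, 4]
assert truncate([1, 2, 3, 4], 2, 'right') == [1, 2]
```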
@@ -368,6 +371,9 @@ def _process_example(self, example):

        # store metadata in dataset, in case user may have keys required in the prediction json files
        metadata = {k: v for k, v in example.items() if k not in self.prompt_template_keys}
+       if self.output_original_text:
+           for orig_text, text_key in zip(template_strings, template_strings_keys):
+               metadata[text_key] = orig_text

        processed_example = {
            'input_ids': input_ids,
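Finally, a sketch of the metadata the second `_process_example` change produces, reusing the docstring's example strings (only the `zip` loop mirrors the diff). Note that because `'<template>'` repeats as a key, only the last literal segment survives under it; the placeholder values themselves are each preserved under their own key.

```python
# Reusing the separated strings from the _separate_template docstring example.
template_strings = ['Context:', ' xxx', ' Question:', ' yyy', ' Answer:', ' zzz']
template_strings_keys = ['<template>', 'context', '<template>', 'question', '<template>', 'label']

metadata = {}
output_original_text = True
if output_original_text:
    for orig_text, text_key in zip(template_strings, template_strings_keys):
        metadata[text_key] = orig_text

print(metadata)
# {'<template>': ' Answer:', 'context': ' xxx', 'question': ' yyy', 'label': ' zzz'}
```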