support value attribution condition #6934
@@ -16,6 +16,7 @@
 import torch

+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
 from nemo.utils import logging
@@ -29,25 +30,65 @@
 SYSTEM_TOKEN = "<extra_id_0>System\n"
 TURN_TOKEN = "<extra_id_1>"

-GUARD_RAIL_INSTRUCTION = {
-    "TEXT_TO_CANONICAL_FORM": "Given a dialogue, for each turn you need to generate a short summary called a canonical form. Generate the canonical form for the last turn in the dialogue.",
-    "CANONICAL_FORM_TO_TEXT": "Given a dialogue, for each turn we also have a short summary called a canonical form. Generate the canonical form given the last turn message and canonical form. Then generate the message.",
+TYPE_INSTRUCTION = {
+    'TEXT_TO_VALUE': "",
+    'VALUE_TO_TEXT': '',
 }


-def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer, mask_role):
+def _mask_targets(
+    target,
+    tokenized_lens,
+    speakers,
+    header_len,
+    s_ids,
+    tokenizer,
+    mask_role,
+    gtype,
+    extra_id_2_token_id,
+    new_line_token_id,
+):
+    """ This function masks the tokens so the loss is computed only on the non-masked role's responses.
+    For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes.
+
+    Args:
+        target (Tensor): input ids
+        tokenized_lens (List[int]): array of lengths of each turn
+        speakers (List[str]): array of speakers of each turn
+        header_len (int): the system prompt length
+        s_ids (List[Tensor]): array of tokenized ids of each turn
+        tokenizer (TokenizerSpec): tokenizer object
+        mask_role (str): the speaker id to be masked from loss computation
+        gtype (str): either 'TEXT_TO_VALUE' or 'VALUE_TO_TEXT'
+        extra_id_2_token_id (int): <extra_id_2> token id
+        new_line_token_id (int): new line token id
+    """
     cur_idx = header_len
     tgt_len = target.shape[0]
     for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
-        # note, sentence piece will add extra empty token in front. s_id has that extra token too
-        skip_name_len = len(tokenizer.text_to_ids(TURN_TOKEN + speaker + END_NAME_SIGNAL))
+        # note, sentence piece will add an extra empty token in front, so compute the difference
+        id1 = tokenizer.text_to_ids("<extra_id_1>")
+        id2 = tokenizer.text_to_ids("<extra_id_1>" + TURN_TOKEN + speaker + END_NAME_SIGNAL)
+        skip_name_len = len(id2) - len(id1)
+        if extra_id_2_token_id is None:
+            raise ValueError("extra_id_2 is not in the vocabulary")
+        if (s_id == extra_id_2_token_id).any().item():
+            if gtype == 'VALUE_TO_TEXT':
+                # the turn contains the token <extra_id_2>
+                assert skip_name_len == torch.where((s_id == extra_id_2_token_id))[0].item()
+                # find the new line token that ends the label segment and skip past it
+                more_skip_len = torch.where((s_id[skip_name_len:] == new_line_token_id))[0][0].item() + 1
+                skip_name_len += more_skip_len
+            elif gtype == 'TEXT_TO_VALUE':
+                skip_name_len = torch.where((s_id == extra_id_2_token_id))[0].item() + 1
         if cur_idx >= tgt_len:
             break
         elif cur_idx + tokenized_len < tgt_len:
             # Check whether the mask is applied to the correct position, the first token is turn token: <extra_id_1>
             # s_id[2:] skips the artifact empty token and the turn token
             # target[cur_idx + 1:cur_idx + tokenized_len] skips the turn token
-            if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[2:]):
+            if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]):
                 logging.warning("a sentence mismatches the corresponding piece " "in the conversation")
         if i == 0:
             # mask the first turn completely to provide at least one turn as context
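As the docstring says, `_mask_targets` ensures the loss is computed only on the non-masked role's responses: the positions that should not be trained on are overwritten with IGNORE_INDEX (the assignments appear in the next hunk). A minimal sketch of why that removes them from the loss, assuming a standard PyTorch cross-entropy setup and that IGNORE_INDEX is -100; the constant itself is defined elsewhere in this file and is not shown in the diff:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumption: matches the module-level constant used by _mask_targets

def loss_on_unmasked_positions(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # cross_entropy skips every position whose target equals ignore_index,
    # which is exactly what _mask_targets arranges for the masked role's turns
    return F.cross_entropy(logits, target, ignore_index=IGNORE_INDEX)

# toy example: 6 positions, 10-word vocabulary; only the last 4 positions carry loss
logits = torch.randn(6, 10)
target = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 4, 7, 1, 3])
print(loss_on_unmasked_positions(logits, target))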
@@ -57,14 +98,21 @@ def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer
             target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
         else:
             # mask up to the name end, need to remove one as skip name has an extra artifact empty token
-            target[cur_idx : cur_idx + skip_name_len - 1] = IGNORE_INDEX
+            target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX
         cur_idx += tokenized_len


 def cannonical_form_formater(cannoical_form):
     return f'<extra_id_2>{cannoical_form}\n'


+def response_value_formater(label):
+    if isinstance(label, str):
+        return '<extra_id_2>' + label + '\n'
+    else:
+        raise ValueError(f'Unknown label type {type(label)}, only str type is supported')
+
+
 def _add_speaker_and_signal(header, source, mask_role, gtype):
     """Add speaker and start/end signal on each round."""
     BEGIN_SIGNAL = ""
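A small illustration of the new helper: `response_value_formater` wraps a per-turn label into an `<extra_id_2>` segment that is appended before or after the turn text depending on the data type. The label string below is made up:

# assuming the response_value_formater defined above; the label string is hypothetical
print(repr(response_value_formater("quality:4,toxicity:0")))
# -> '<extra_id_2>quality:4,toxicity:0\n'

# a non-string label raises ValueError("Unknown label type ...")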
@@ -76,56 +124,57 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
             sentence["value"] = (
                 BEGIN_SIGNAL + role_token + sentence_from + END_NAME_SIGNAL + sentence["value"] + END_SIGNAL
             )
-        elif gtype == "TEXT_TO_CANONICAL_FORM":
+        elif gtype == "VALUE_TO_TEXT":
             sentence["value"] = (
                 BEGIN_SIGNAL
                 + role_token
                 + sentence_from
                 + END_NAME_SIGNAL
+                + (response_value_formater(sentence['label']) if 'label' in sentence else '')
                 + sentence["value"]
                 + END_SIGNAL
-                + cannonical_form_formater(sentence['canonical_form'])
             )
-        elif gtype == "CANONICAL_FORM_TO_TEXT":
+        elif gtype == "TEXT_TO_VALUE":
             sentence["value"] = (
                 BEGIN_SIGNAL
                 + role_token
                 + sentence_from
                 + END_NAME_SIGNAL
-                + cannonical_form_formater(sentence['canonical_form'])
                 + sentence["value"]
                 + END_SIGNAL
+                + (response_value_formater(sentence['label']) if 'label' in sentence else '')
             )
         else:
-            raise ValueError(f"source type {gtype} not supported")
+            raise ValueError(
+                f"source type {gtype} not supported, only 'VALUE_TO_TEXT' and 'TEXT_TO_VALUE' are supported"
+            )
         conversation += sentence["value"]
         # if the last turn is not masked, add next token start token to the end, which will be included for loss calculation
         if sentence_from != mask_role and i == len(source) - 1:
             conversation += TURN_TOKEN
     return conversation


-def preprocess(
-    source: dict, tokenizer: TokenizerSpec,
-):
+def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int, new_line_token_id: int):
     """
     Given a conversation list. This transform:
     1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
     2. Concatenate conversations together;
     3. Tokenize the concatenated conversation;
     4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
     """
-    canonical_type = None
+    data_type = None
     if 'type' in source:
-        canonical_type = source['type']
-        assert canonical_type in GUARD_RAIL_INSTRUCTION, f"source type {canonical_type} not supported"
+        data_type = source['type']
+        assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported"
     # add end signal and concatenate together
     conversation = source['system']
-    if canonical_type is not None:
-        conversation = conversation + '\n' + GUARD_RAIL_INSTRUCTION[canonical_type]
+    if data_type is not None:
+        if TYPE_INSTRUCTION[data_type] != '':
+            conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type]
     mask_role = source.get('mask', 'User')
-    header = f"{SYSTEM_TOKEN}{conversation}\n\n"
-    conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, canonical_type)
+    header = f"{SYSTEM_TOKEN}{conversation}"
+    conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type)
     # tokenize conversations
     input_ids = tokenizer.text_to_ids(conversation)
     target = copy.deepcopy(input_ids)
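To make the two branches above concrete, here is a standalone sketch of how a single assistant turn is laid out in each mode. The turn text and label are made up, role_token is assumed to be the turn token, and END_NAME_SIGNAL/END_SIGNAL are assumed to be newlines (those constants are defined elsewhere in this file and are not shown in the diff):

TURN_TOKEN = "<extra_id_1>"
END_NAME_SIGNAL = "\n"  # assumption: not shown in this diff
END_SIGNAL = "\n"       # assumption: not shown in this diff

def format_turn(speaker: str, text: str, label: str, gtype: str) -> str:
    # the label segment is what response_value_formater produces
    value_segment = "<extra_id_2>" + label + "\n"
    if gtype == "VALUE_TO_TEXT":
        # value attributes come before the message: condition the text on the desired values
        return TURN_TOKEN + speaker + END_NAME_SIGNAL + value_segment + text + END_SIGNAL
    elif gtype == "TEXT_TO_VALUE":
        # value attributes come after the message: predict the values from the text
        return TURN_TOKEN + speaker + END_NAME_SIGNAL + text + END_SIGNAL + value_segment
    raise ValueError(f"unsupported gtype {gtype}")

# hypothetical turn
print(repr(format_turn("Assistant", "Sure, here is a recipe.", "quality:4,helpfulness:4", "VALUE_TO_TEXT")))
print(repr(format_turn("Assistant", "Sure, here is a recipe.", "quality:4,helpfulness:4", "TEXT_TO_VALUE")))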
@@ -134,37 +183,76 @@ def preprocess(
     ids = []
     tokenized_lens = []
     for s in source['conversations']:
-        tokenized_sentence = tokenizer.text_to_ids(s["value"])
-        ids.append(torch.tensor(tokenized_sentence))
-        # remove one token as it adds an empty token in front
-        tokenized_lens.append(len(tokenized_sentence) - 1)
+        if isinstance(tokenizer, SentencePieceTokenizer):
+            tokenized_sentence = tokenizer.text_to_ids(s["value"])
+            ids.append(torch.tensor(tokenized_sentence)[1:])
+            # remove one token as it adds an empty token in front
+            tokenized_lens.append(len(tokenized_sentence) - 1)
+        else:
+            tokenized_sentence = tokenizer.text_to_ids(s["value"])
+            ids.append(torch.tensor(tokenized_sentence))
+            # no artifact empty token is added in front, keep the full length
+            tokenized_lens.append(len(tokenized_sentence))
     speakers = [sentence["from"] for sentence in source['conversations']]
     assert mask_role in speakers, "mask role not in the conversation"
     target = torch.LongTensor(target)
     # not going to train on the header
     target[:header_len] = IGNORE_INDEX
     input_ids = torch.LongTensor(input_ids)

-    _mask_targets(target, tokenized_lens, speakers, header_len, ids, tokenizer, mask_role)
+    _mask_targets(
+        target,
+        tokenized_lens,
+        speakers,
+        header_len,
+        ids,
+        tokenizer,
+        mask_role,
+        data_type,
+        extra_id_2_token_id,
+        new_line_token_id,
+    )
     mask = (target != IGNORE_INDEX).bool()
     assert mask.sum().item() != 0, "mask is empty"
     return dict(input_ids=input_ids, mask=mask)


+def _check_token_in_vocab(tokenizer, token):
+    ids = tokenizer.text_to_ids(token)
+    if isinstance(tokenizer, SentencePieceTokenizer):
+        return len(ids) == 2
+    else:
+        return len(ids) == 1
+
+
 class GPTSFTChatDataset(GPTSFTDataset):
     def _build_samples_mapping(self):
         super()._build_samples_mapping()
-        assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
-        assert '<extra_id_0>' in self.tokenizer.vocab, "<extra_id_0> not in the tokenizer vocab. not supported"
-        assert '<extra_id_1>' in self.tokenizer.vocab, "<extra_id_1> not in the tokenizer vocab. not supported"
+        assert _check_token_in_vocab(
+            self.tokenizer, '<extra_id_0>'
+        ), "<extra_id_0> not in the tokenizer vocab. not supported"
+        assert _check_token_in_vocab(
+            self.tokenizer, '<extra_id_1>'
+        ), "<extra_id_1> not in the tokenizer vocab. not supported"
+        # calculate the <extra_id_2> token id
+        if _check_token_in_vocab(self.tokenizer, '<extra_id_2>'):
+            ids_1 = self.tokenizer.text_to_ids('<extra_id_1><extra_id_2>')
+            ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
+            self.extra_id_2_token_id = ids_1[len(ids_2) :][0]
+        else:
+            self.extra_id_2_token_id = None
+        ids_1 = self.tokenizer.text_to_ids('<extra_id_1>\n')
+        ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
+        self.new_line_token_id = ids_1[len(ids_2) :][0]

Review thread on the line `ids_1 = self.tokenizer.text_to_ids('<extra_id_1><extra_id_2>')`:
"is this a typo?"
"Not sure why we can't get the text_to_ids('<extra_id_2>') directly. This looks pretty hacky."
"this is to handle the sentencepiece tokenizer which adds a special token in front. A hacky way I agree."
"Btw this check is super slow since ..."
"Hmm, nit but isn't it just gonna be ..."
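Both `_check_token_in_vocab` and the `extra_id_2_token_id` computation above work around the fact, raised in the review thread, that the SentencePiece tokenizer prepends an artifact token when encoding a bare string, so a special token's id cannot be read off by tokenizing it alone. The workaround is to tokenize the string behind a known prefix and slice the prefix's ids off. A self-contained sketch of that pattern; `tokenizer` is any object exposing `text_to_ids`, and the helper name is hypothetical:

def single_token_id(tokenizer, token: str, prefix: str = "<extra_id_1>") -> int:
    """Get the id of `token`, tolerating tokenizers that prepend an artifact token.

    Tokenize prefix+token and prefix alone, then drop the shared prefix ids;
    whatever remains must be the ids of `token` itself.
    """
    with_prefix = tokenizer.text_to_ids(prefix + token)
    prefix_only = tokenizer.text_to_ids(prefix)
    remainder = with_prefix[len(prefix_only):]
    assert len(remainder) == 1, f"{token!r} is not a single token in this vocabulary"
    return remainder[0]

# usage mirroring the _build_samples_mapping code above:
# extra_id_2_token_id = single_token_id(tokenizer, '<extra_id_2>')
# new_line_token_id = single_token_id(tokenizer, '\n')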
     def _process_example(self, example):
         """
         Create an example by concatenating text and answer.
         Truncation is carried out when needed, but it is performed only on the prompt side.
         BOS, EOS, and SEP are added if specified.
         """
-        result = preprocess(example, self.tokenizer)
+        result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id)

         return result
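Finally, for readers unfamiliar with the chat SFT format, a hypothetical input record showing the fields that `preprocess` (and therefore `_process_example`) reads. The field names come from the code above; every concrete string below is made up:

example = {
    "system": "You are a helpful assistant.",  # becomes part of the <extra_id_0>System header
    "type": "VALUE_TO_TEXT",                   # or "TEXT_TO_VALUE"; selects the TYPE_INSTRUCTION branch
    "mask": "User",                            # role whose turns are excluded from the loss (default "User")
    "conversations": [
        {"from": "User", "value": "Give me a short greeting."},
        {"from": "Assistant", "value": "Hello there!", "label": "quality:4,toxicity:0"},  # 'label' is optional per turn
    ],
}

# result = preprocess(example, tokenizer, extra_id_2_token_id, new_line_token_id)
# result["input_ids"]: token ids of the full formatted conversation
# result["mask"]:      True only at the positions that contribute to the loss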
Review thread:
"Can you call this steerlm_config instead?"
"The paper is in anonymity; we don't want to mention the name of the method."