support value attribution condition #6934

Merged (24 commits) on Jun 29, 2023
Changes from all commits
@@ -35,4 +35,9 @@
share: False # whether create a public URL
username: test # user name for web client
password: test2 # password for web client
web_port: 9889 # the port number of the web server
chat: False # use the chat interface
chatbot_config:
Contributor: Can you call this steerlm_config instead?

Collaborator (author): The paper is under anonymous review; we don't want to mention the name of the method.

  value: False # whether to inject the value attributes
  user: User
  assistant: Assistant
  system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
examples/nlp/language_modeling/megatron_gpt_eval.py (8 changes: 7 additions & 1 deletion)
@@ -15,6 +15,7 @@
import asyncio
import os
import threading
from functools import partial

import torch
from omegaconf import OmegaConf, open_dict
@@ -301,7 +302,12 @@ def main(cfg) -> None:
if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
if cfg.chat:
web_ui = get_chatbot_demo
defaults = {
'user': cfg.chatbot_config.user,
'assistant': cfg.chatbot_config.assistant,
'system': cfg.chatbot_config.system,
}
web_ui = partial(get_chatbot_demo, defaults=defaults, value=cfg.chatbot_config.value)
else:
web_ui = get_demo
loop = asyncio.new_event_loop()
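For context on the change above: functools.partial pre-binds the chatbot defaults and the value flag, so web_ui can still be called later with the same arguments that get_demo expects. A minimal sketch of the idea, using a hypothetical demo function rather than the real get_chatbot_demo (whose exact signature is not shown here):

from functools import partial

def demo(port, defaults=None, value=False):
    # hypothetical stand-in for get_chatbot_demo / get_demo
    print(f"serving web UI on port {port} with defaults={defaults}, value={value}")

defaults = {'user': 'User', 'assistant': 'Assistant', 'system': 'A chat between ...'}
web_ui = partial(demo, defaults=defaults, value=True)

# downstream code can call web_ui exactly as it previously called the unbound factory
web_ui(9889)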
@@ -16,6 +16,7 @@

import torch

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.utils import logging
@@ -29,25 +30,65 @@
SYSTEM_TOKEN = "<extra_id_0>System\n"
TURN_TOKEN = "<extra_id_1>"

GUARD_RAIL_INSTRUCTION = {
"TEXT_TO_CANONICAL_FORM": "Given a dialogue, for each turn you need to generate a short summary called a canonical form. Generate the canonical form for the last turn in the dialogue.",
"CANONICAL_FORM_TO_TEXT": "Given a dialogue, for each turn we also have a short summary called a canonical form. Generate the canonical form given the last turn message and canonical form. Then generate the message.",
TYPE_INSTRUCTION = {
'TEXT_TO_VALUE': "",
'VALUE_TO_TEXT': '',
}


def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer, mask_role):
def _mask_targets(
target,
tokenized_lens,
speakers,
header_len,
s_ids,
tokenizer,
mask_role,
gtype,
extra_id_2_token_id,
new_line_token_id,
):
""" This function masks the tokens so the loss is computed only on the non-masked role's responses.
For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes.

Args:
target (Tensor): input ids
tokenized_lens (List[int]): array of lengths of each turn
speakers (List[str]): array of speakers of each turn
header_len (int): the system prompt length
s_ids (List[Tensor]): array of tokenized ids of each turn
tokenizer (TokenizerSpec): tokenizer object
mask_role (str): the speaker id to be masked from loss computation
gtype (str): either 'TEXT_TO_VALUE' or 'VALUE_TO_TEXT'
extra_id_2_token_id (int): <extra_id_2> token id
new_line_token_id (int): new line token id

"""
cur_idx = header_len
tgt_len = target.shape[0]
for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
# note, sentence piece will add extra empty token in front. s_id has that extra token too
skip_name_len = len(tokenizer.text_to_ids(TURN_TOKEN + speaker + END_NAME_SIGNAL))
# note, sentence piece will add extra empty token in front. has to compute the diff
id1 = tokenizer.text_to_ids("<extra_id_1>")
id2 = tokenizer.text_to_ids("<extra_id_1>" + TURN_TOKEN + speaker + END_NAME_SIGNAL)
skip_name_len = len(id2) - len(id1)
if extra_id_2_token_id is None:
raise ValueError("extra_id_2 is not in the vocabulary")
if (s_id == extra_id_2_token_id).any().item():
if gtype == 'VALUE_TO_TEXT':
# if contains the token <extra_id_2>
assert skip_name_len == torch.where((s_id == extra_id_2_token_id))[0].item()
# find new line token id 14
more_skip_len = torch.where((s_id[skip_name_len:] == new_line_token_id))[0][0].item() + 1
skip_name_len += more_skip_len
elif gtype == 'TEXT_TO_VALUE':
skip_name_len = torch.where((s_id == extra_id_2_token_id))[0].item() + 1
if cur_idx >= tgt_len:
break
elif cur_idx + tokenized_len < tgt_len:
# Check whether the mask is applied to the correct position, the first token is turn token: <extra_id_1>
# s_id[2:] skips the artifact empty token and the turn token
# target[cur_idx + 1:cur_idx + tokenized_len] skip the turn token
if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[2:]):
if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]):
logging.warning("a sentence mismatches the corresponding piece " "in the conversation")
if i == 0:
# mask the first turn completely to provide at least one turn as context
@@ -57,14 +98,21 @@ def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer
target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
else:
# mask up to the name end, need to remove one as skip name has an extra artifact empty token
target[cur_idx : cur_idx + skip_name_len - 1] = IGNORE_INDEX
target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX
cur_idx += tokenized_len
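To make the masking concrete, here is a small illustrative sketch (not code from this PR) of the pattern _mask_targets follows: the span belonging to the masked role is overwritten with IGNORE_INDEX, and the loss mask is later derived from the result. The value -100 for IGNORE_INDEX is an assumption for the illustration only.

import torch

IGNORE_INDEX = -100  # assumed sentinel; the dataset module defines its own constant

# toy target of 10 token ids; pretend positions 2..5 belong to the masked role
target = torch.arange(100, 110)
cur_idx, tokenized_len = 2, 4
target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX

mask = (target != IGNORE_INDEX).bool()
print(target)  # tensor([100, 101, -100, -100, -100, -100, 106, 107, 108, 109])
print(mask)    # True everywhere except the masked span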


def cannonical_form_formater(cannoical_form):
return f'<extra_id_2>{cannoical_form}\n'


def response_value_formater(label):
if isinstance(label, str):
return '<extra_id_2>' + label + '\n'
else:
raise ValueError(f'Unknown label type {type(label)}, only str type is supported')


def _add_speaker_and_signal(header, source, mask_role, gtype):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = ""
@@ -76,56 +124,57 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
sentence["value"] = (
BEGIN_SIGNAL + role_token + sentence_from + END_NAME_SIGNAL + sentence["value"] + END_SIGNAL
)
elif gtype == "TEXT_TO_CANONICAL_FORM":
elif gtype == "VALUE_TO_TEXT":
sentence["value"] = (
BEGIN_SIGNAL
+ role_token
+ sentence_from
+ END_NAME_SIGNAL
+ (response_value_formater(sentence['label']) if 'label' in sentence else '')
+ sentence["value"]
+ END_SIGNAL
+ cannonical_form_formater(sentence['canonical_form'])
)
elif gtype == "CANONICAL_FORM_TO_TEXT":
elif gtype == "TEXT_TO_VALUE":
sentence["value"] = (
BEGIN_SIGNAL
+ role_token
+ sentence_from
+ END_NAME_SIGNAL
+ cannonical_form_formater(sentence['canonical_form'])
+ sentence["value"]
+ END_SIGNAL
+ (response_value_formater(sentence['label']) if 'label' in sentence else '')
)
else:
raise ValueError(f"source type {gtype} not supported")
raise ValueError(
f"source type {gtype} not supported, only 'VALUE_TO_TEXT' and 'TEXT_TO_VALUE' are supported"
)
conversation += sentence["value"]
# if the last turn is not masked, add next token start token to the end, which will be included for loss calculation
if sentence_from != mask_role and i == len(source) - 1:
conversation += TURN_TOKEN
return conversation
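As a rough illustration of the two serialization formats (a sketch assuming END_NAME_SIGNAL and END_SIGNAL are both newlines, which this hunk does not show, and with a made-up attribute string as the label), a single Assistant turn would be laid out roughly as follows:

TURN_TOKEN = "<extra_id_1>"
LABEL_START = "<extra_id_2>"
NL = "\n"  # assumed value of END_NAME_SIGNAL / END_SIGNAL

value = "Sure, here is the answer."
label = "quality:4,toxicity:0"  # hypothetical value-attribute string

# VALUE_TO_TEXT: the label is injected right after the speaker name,
# so the model generates the response conditioned on the value attributes
value_to_text = TURN_TOKEN + "Assistant" + NL + LABEL_START + label + NL + value + NL

# TEXT_TO_VALUE: the label follows the response,
# so the model predicts the value attributes from the text
text_to_value = TURN_TOKEN + "Assistant" + NL + value + NL + LABEL_START + label + NL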


def preprocess(
source: dict, tokenizer: TokenizerSpec,
):
def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int, new_line_token_id: int):
"""
Given a conversation list. This transform:
1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
canonical_type = None
data_type = None
if 'type' in source:
canonical_type = source['type']
assert canonical_type in GUARD_RAIL_INSTRUCTION, f"source type {canonical_type} not supported"
data_type = source['type']
assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported"
# add end signal and concatenate together
conversation = source['system']
if canonical_type is not None:
conversation = conversation + '\n' + GUARD_RAIL_INSTRUCTION[canonical_type]
if data_type is not None:
if TYPE_INSTRUCTION[data_type] != '':
conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type]
mask_role = source.get('mask', 'User')
header = f"{SYSTEM_TOKEN}{conversation}\n\n"
conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, canonical_type)
header = f"{SYSTEM_TOKEN}{conversation}"
conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type)
# tokenize conversations
input_ids = tokenizer.text_to_ids(conversation)
target = copy.deepcopy(input_ids)
@@ -134,37 +183,76 @@ def preprocess(
ids = []
tokenized_lens = []
for s in source['conversations']:
tokenized_sentence = tokenizer.text_to_ids(s["value"])
ids.append(torch.tensor(tokenized_sentence))
# remove one token as it adds an empty token in front
tokenized_lens.append(len(tokenized_sentence) - 1)
if isinstance(tokenizer, SentencePieceTokenizer):
tokenized_sentence = tokenizer.text_to_ids(s["value"])
ids.append(torch.tensor(tokenized_sentence)[1:])
# remove one token as it adds an empty token in front
tokenized_lens.append(len(tokenized_sentence) - 1)
else:
tokenized_sentence = tokenizer.text_to_ids(s["value"])
ids.append(torch.tensor(tokenized_sentence))
# non-sentencepiece tokenizers do not add an artifact token in front, so keep the full length
tokenized_lens.append(len(tokenized_sentence))
speakers = [sentence["from"] for sentence in source['conversations']]
assert mask_role in speakers, "mask role not in the conversation"
target = torch.LongTensor(target)
# not going to train on the header
target[:header_len] = IGNORE_INDEX
input_ids = torch.LongTensor(input_ids)

_mask_targets(target, tokenized_lens, speakers, header_len, ids, tokenizer, mask_role)
_mask_targets(
target,
tokenized_lens,
speakers,
header_len,
ids,
tokenizer,
mask_role,
data_type,
extra_id_2_token_id,
new_line_token_id,
)
mask = (target != IGNORE_INDEX).bool()
assert mask.sum().item() != 0, "mask is empty"
return dict(input_ids=input_ids, mask=mask)
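The returned mask marks the positions that contribute to the loss. A minimal sketch of how such a mask is typically consumed downstream (an assumption about the consumer, not code from this PR): compute a per-token loss and average only over unmasked positions.

import torch
import torch.nn.functional as F

vocab_size, seq_len = 32, 6
logits = torch.randn(seq_len, vocab_size)          # model outputs for one sequence
labels = torch.randint(0, vocab_size, (seq_len,))  # target token ids
mask = torch.tensor([0, 0, 1, 1, 1, 0]).bool()     # True only on the tokens to train on

per_token = F.cross_entropy(logits, labels, reduction='none')
loss = (per_token * mask).sum() / mask.sum()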


def _check_token_in_vocab(tokenizer, token):
ids = tokenizer.text_to_ids(token)
if isinstance(tokenizer, SentencePieceTokenizer):
return len(ids) == 2
else:
return len(ids) == 1


class GPTSFTChatDataset(GPTSFTDataset):
def _build_samples_mapping(self):
super()._build_samples_mapping()
assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
assert '<extra_id_0>' in self.tokenizer.vocab, "<extra_id_0> not in the tokenizer vocab. not supported"
assert '<extra_id_1>' in self.tokenizer.vocab, "<extra_id_1> not in the tokenizer vocab. not supported"
assert _check_token_in_vocab(
self.tokenizer, '<extra_id_0>'
), "<extra_id_0> not in the tokenizer vocab. not supported"
assert _check_token_in_vocab(
self.tokenizer, '<extra_id_1>'
), "<extra_id_1> not in the tokenizer vocab. not supported"
# calculate <extra_id_2> id value
if _check_token_in_vocab(self.tokenizer, '<extra_id_2>'):
ids_1 = self.tokenizer.text_to_ids('<extra_id_1><extra_id_2>')
Collaborator: Is this a typo?

Collaborator: Not sure why we can't get text_to_ids('<extra_id_2>') directly. This looks pretty hacky.

Collaborator (author): This is to handle the SentencePiece tokenizer, which adds a special token in front. A hacky way, I agree.

Contributor: By the way, this check is super slow since self.tokenizer.vocab is a list, not a dictionary/set. It is actually faster to do len(self.tokenizer.text_to_ids('<extra_id_2>')) == 1.

Collaborator (Zhilin123, Jun 28, 2023): Hmm, nit, but isn't it just going to be self.extra_id_2_token_id = self.tokenizer.text_to_ids('<extra_id_2>')[-1]?

ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
self.extra_id_2_token_id = ids_1[len(ids_2) :][0]
else:
self.extra_id_2_token_id = None
ids_1 = self.tokenizer.text_to_ids('<extra_id_1>\n')
ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
self.new_line_token_id = ids_1[len(ids_2) :][0]
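To illustrate the two approaches discussed in the thread above, here is a self-contained sketch with a hypothetical tokenizer that, like SentencePiece, prepends an artifact id to every encoding; the real tokenizer and ids differ.

import re

class FakeSentencePieceLikeTokenizer:
    """Hypothetical stand-in that prepends an artifact id (0) like SentencePiece does."""
    _vocab = {'<extra_id_1>': 11, '<extra_id_2>': 12, '\n': 14}

    def text_to_ids(self, text):
        ids = [0]  # artifact token in front
        for piece in re.findall(r'<extra_id_\d>|\n', text):  # naive split, for illustration only
            ids.append(self._vocab[piece])
        return ids

tok = FakeSentencePieceLikeTokenizer()

# difference trick used in the PR: tokenize with and without the target token
ids_1 = tok.text_to_ids('<extra_id_1><extra_id_2>')  # [0, 11, 12]
ids_2 = tok.text_to_ids('<extra_id_1>')              # [0, 11]
extra_id_2_token_id = ids_1[len(ids_2):][0]          # 12

# simpler alternative raised in review: take the last id directly
assert tok.text_to_ids('<extra_id_2>')[-1] == extra_id_2_token_id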

def _process_example(self, example):
"""
Create an example by concatenating text and answer.
Truncation is carried out when needed, but it is performed only on the prompt side.
BOS, EOS, and SEP are added if specified.
"""
result = preprocess(example, self.tokenizer)
result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id)

return result

@@ -17,6 +17,7 @@
import re
from typing import Any, Dict, Optional, Union

import omegaconf
import torch
from omegaconf import open_dict
from omegaconf.dictconfig import DictConfig
@@ -222,6 +223,10 @@ def _build_tokenizer(self):
legacy=legacy,
)

if self._cfg.tokenizer.get('additional_special_tokens', None) is not None:
tokens_list = omegaconf.OmegaConf.to_object(self._cfg.tokenizer.additional_special_tokens)
self.tokenizer.add_special_tokens({'additional_special_tokens': tokens_list})

def on_train_start(self) -> None:
super().on_train_start()
self.init_global_step = self.trainer.global_step
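A rough sketch of what this hook enables: listing extra special tokens under the tokenizer section of the model config so they are registered when the tokenizer is built. The config excerpt and key names below are assumptions modeled on the cfg access in the hunk above; with the real NeMo tokenizer the final call would be tokenizer.add_special_tokens({'additional_special_tokens': tokens_list}).

from omegaconf import OmegaConf

# hypothetical excerpt of a model config; only the tokenizer block is shown
cfg = OmegaConf.create(
    """
    tokenizer:
      library: sentencepiece
      additional_special_tokens: ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']
    """
)

if cfg.tokenizer.get('additional_special_tokens', None) is not None:
    tokens_list = OmegaConf.to_object(cfg.tokenizer.additional_special_tokens)
    print(tokens_list)  # ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']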