Commit 4e7afbb: Dialogue dataset (NVIDIA#6654)
* chatbot interface

Signed-off-by: Yi Dong <[email protected]>

* latest gradio

Signed-off-by: Yi Dong <[email protected]>

* default greedy

Signed-off-by: Yi Dong <[email protected]>

* better chatbot

Signed-off-by: Yi Dong <[email protected]>

* handle preamble

Signed-off-by: Yi Dong <[email protected]>

* added chatbot training capability

Signed-off-by: Yi Dong <[email protected]>

* added chatbot ui

Signed-off-by: Yi Dong <[email protected]>

* remove debug code

Signed-off-by: Yi Dong <[email protected]>

* default human

Signed-off-by: Yi Dong <[email protected]>

* use special token for roles

Signed-off-by: Yi Dong <[email protected]>

* special tokens

Signed-off-by: Yi Dong <[email protected]>

* fix name

Signed-off-by: Yi Dong <[email protected]>

* new chat dataset

Signed-off-by: Yi Dong <[email protected]>

* fix the system token

Signed-off-by: Yi Dong <[email protected]>

* upgrade gradio

Signed-off-by: Yi Dong <[email protected]>

* save the chat history

Signed-off-by: Yi Dong <[email protected]>

* update ui

Signed-off-by: root <[email protected]>

* update chat interface

Signed-off-by: Yi Dong <[email protected]>

* handles canonical form

Signed-off-by: Yi Dong <[email protected]>

* new sft chatbot

Signed-off-by: Yi Dong <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change format

Signed-off-by: Yi Dong <[email protected]>

* check extra_id in the tokenizer

Signed-off-by: Yi Dong <[email protected]>

* added vocab property check

Signed-off-by: Yi Dong <[email protected]>

* added missing file

Signed-off-by: Yi Dong <[email protected]>

---------

Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
4 people authored and hsiehjackson committed Jun 2, 2023
1 parent f736f60 commit 4e7afbb
Showing 9 changed files with 646 additions and 24 deletions.
@@ -34,4 +34,5 @@ web_server: False # whether to launch the web inference server
share: False # whether to create a public URL
username: test # user name for the web client
password: test2 # password for the web client
web_port: 9889 # the port number of the web server
+chat: False # use the chat interface
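These keys together drive the optional web client. A minimal sketch of flipping them programmatically follows; NeMo configs resolve to OmegaConf objects, and the YAML file name here is hypothetical:

from omegaconf import OmegaConf

cfg = OmegaConf.load("megatron_gpt_inference.yaml")  # hypothetical path to the config above
cfg.web_server = True  # launch the Gradio-based web UI
cfg.chat = True        # select the chatbot interface instead of the plain demo
cfg.web_port = 9889    # port the web server listens on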
8 changes: 6 additions & 2 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -23,7 +23,7 @@

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
-from nemo.collections.nlp.modules.common.megatron_web_server import get_demo
+from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
@@ -277,9 +277,13 @@ def main(cfg) -> None:
    if cfg.server:
        if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
            if cfg.web_server:
+               if cfg.chat:
+                   web_ui = get_chatbot_demo
+               else:
+                   web_ui = get_demo
                loop = asyncio.new_event_loop()
                thread = threading.Thread(
-                   target=get_demo,
+                   target=web_ui,
                    daemon=True,
                    args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
                )
@@ -66,6 +66,7 @@ model:
  ffn_dropout: 0.0

  data:
+   chat: False # whether to use chatbot data
    train_ds:
      # Example of how to specify paths to multiple datasets
      # file_names:
nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py (new file)
@@ -0,0 +1,207 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.utils import logging

__all__ = ['GPTSFTChatDataset']

IGNORE_INDEX = -100
END_SIGNAL = "\n"
END_NAME_SIGNAL = "\n"

SYSTEM_TOKEN = "<extra_id_0>System\n"
TURN_TOKEN = "<extra_id_1>"

GUARD_RAIL_INSTRUCTION = {
    "TEXT_TO_CANONICAL_FORM": "Given a dialogue, for each turn you need to generate a short summary called a canonical form. Generate the canonical form for the last turn in the dialogue.",
    "CANONICAL_FORM_TO_TEXT": "Given a dialogue, for each turn we also have a short summary called a canonical form. Generate the canonical form given the last turn message and canonical form. Then generate the message.",
}
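# For illustration: a training record opts into one of these guardrail modes via
# a 'type' field (hypothetical example: {"type": "TEXT_TO_CANONICAL_FORM", ...});
# preprocess() below then appends the matching instruction to the system prompt
# before the conversation is serialized.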


def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer, mask_role):
    cur_idx = header_len
    tgt_len = target.shape[0]
    for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
        # note: SentencePiece adds an extra empty token in front; s_id carries that extra token too
        skip_name_len = len(tokenizer.text_to_ids(TURN_TOKEN + speaker + END_NAME_SIGNAL))
        if cur_idx >= tgt_len:
            break
        elif cur_idx + tokenized_len < tgt_len:
            # Check whether the mask is applied at the correct position; the first token is the turn token <extra_id_1>.
            # s_id[2:] skips the artifact empty token and the turn token;
            # target[cur_idx + 1 : cur_idx + tokenized_len] skips the turn token.
            if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[2:]):
                logging.warning("a sentence mismatches the corresponding piece in the conversation")
        if i == 0:
            # mask the first turn completely to provide at least one turn as context
            target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
        elif speaker == mask_role:
            # leave the first human tag unmasked
            target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
        else:
            # mask up to the end of the name; subtract one because skip_name_len counts the extra artifact empty token
            target[cur_idx : cur_idx + skip_name_len - 1] = IGNORE_INDEX
        cur_idx += tokenized_len
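# Toy illustration of the policy above, with turns [User, Assistant] and
# mask_role == "User": turn 0 is masked entirely (it serves as context only),
# and in the Assistant turn only the leading "<extra_id_1>Assistant\n" span is
# masked, so the loss is computed on the assistant's reply tokens alone.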


def cannonical_form_formater(cannoical_form):
    # wrap a canonical form in its special token, terminated by a newline
    return f'<extra_id_2>{cannoical_form}\n'
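# For illustration (the argument value is hypothetical):
#   cannonical_form_formater("greet the user")  ->  '<extra_id_2>greet the user\n'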


def _add_speaker_and_signal(header, source, mask_role, gtype):
    """Add the speaker name and start/end signals to each turn."""
    BEGIN_SIGNAL = ""
    conversation = header
    for i, sentence in enumerate(source):
        sentence_from = sentence["from"]
        role_token = TURN_TOKEN
        if gtype is None:
            sentence["value"] = (
                BEGIN_SIGNAL + role_token + sentence_from + END_NAME_SIGNAL + sentence["value"] + END_SIGNAL
            )
        elif gtype == "TEXT_TO_CANONICAL_FORM":
            sentence["value"] = (
                BEGIN_SIGNAL
                + role_token
                + sentence_from
                + END_NAME_SIGNAL
                + sentence["value"]
                + END_SIGNAL
                + cannonical_form_formater(sentence['canonical_form'])
            )
        elif gtype == "CANONICAL_FORM_TO_TEXT":
            sentence["value"] = (
                BEGIN_SIGNAL
                + role_token
                + sentence_from
                + END_NAME_SIGNAL
                + cannonical_form_formater(sentence['canonical_form'])
                + sentence["value"]
                + END_SIGNAL
            )
        else:
            raise ValueError(f"source type {gtype} not supported")
        conversation += sentence["value"]
        # if the last turn is not masked, append the next turn token so that it is included in the loss calculation
        if sentence_from != mask_role and i == len(source) - 1:
            conversation += TURN_TOKEN
    return conversation


def preprocess(
    source: dict, tokenizer: TokenizerSpec,
):
    """
    Given a conversation list, this transform:
    1. adds the turn token and speaker name at the beginning of each sentence, with end signal '\n';
    2. concatenates the conversation turns together;
    3. tokenizes the concatenated conversation;
    4. makes a deepcopy as the target and masks the mask role's words with IGNORE_INDEX.
    """
    canonical_type = None
    if 'type' in source:
        canonical_type = source['type']
        assert canonical_type in GUARD_RAIL_INSTRUCTION, f"source type {canonical_type} not supported"
    # add end signal and concatenate together
    conversation = source['system']
    if canonical_type is not None:
        conversation = conversation + '\n' + GUARD_RAIL_INSTRUCTION[canonical_type]
    mask_role = source.get('mask', 'User')
    header = f"{SYSTEM_TOKEN}{conversation}\n\n"
    conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, canonical_type)
    # tokenize conversations
    input_ids = tokenizer.text_to_ids(conversation)
    target = copy.deepcopy(input_ids)
    header_len = len(tokenizer.text_to_ids(header))

    ids = []
    tokenized_lens = []
    for s in source['conversations']:
        tokenized_sentence = tokenizer.text_to_ids(s["value"])
        ids.append(torch.tensor(tokenized_sentence))
        # remove one token because the tokenizer adds an empty token in front
        tokenized_lens.append(len(tokenized_sentence) - 1)
    speakers = [sentence["from"] for sentence in source['conversations']]
    assert mask_role in speakers, "mask role not in the conversation"
    target = torch.LongTensor(target)
    # not going to train on the header
    target[:header_len] = IGNORE_INDEX
    input_ids = torch.LongTensor(input_ids)

    _mask_targets(target, tokenized_lens, speakers, header_len, ids, tokenizer, mask_role)
    mask = (target != IGNORE_INDEX).bool()
    assert mask.sum().item() != 0, "mask is empty"
    return dict(input_ids=input_ids, mask=mask)
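# A minimal sketch of the record format consumed by preprocess(); the field
# names come from the code above, the values are hypothetical:
#
#   example = {
#       "system": "You are a helpful assistant.",
#       "mask": "User",  # role whose turns are excluded from the loss
#       "conversations": [
#           {"from": "User", "value": "What is NeMo?"},
#           {"from": "Assistant", "value": "A toolkit for conversational AI."},
#       ],
#   }
#
# With no 'type' field, _add_speaker_and_signal serializes it roughly as:
#
#   <extra_id_0>System
#   You are a helpful assistant.
#
#   <extra_id_1>User
#   What is NeMo?
#   <extra_id_1>Assistant
#   A toolkit for conversational AI.
#   <extra_id_1>
#
# The trailing <extra_id_1> is appended because the last turn belongs to the
# non-masked role, so its end-of-turn token is included in the loss.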


class GPTSFTChatDataset(GPTSFTDataset):
    def _build_samples_mapping(self):
        super()._build_samples_mapping()
        assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
        assert '<extra_id_0>' in self.tokenizer.vocab, "<extra_id_0> not in the tokenizer vocab. not supported"
        assert '<extra_id_1>' in self.tokenizer.vocab, "<extra_id_1> not in the tokenizer vocab. not supported"
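    # A hedged pre-flight check mirroring the asserts above; useful before a long
    # training run (the tokenizer is assumed to follow the TokenizerSpec interface):
    #
    #   ok = hasattr(tokenizer, "vocab") and all(
    #       t in tokenizer.vocab for t in ("<extra_id_0>", "<extra_id_1>")
    #   )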

    def _process_example(self, example):
        """
        Preprocess a conversation example into concatenated, tokenized input ids
        and a loss mask (see `preprocess` above); turns from the mask role are
        excluded from the loss.
        """
        result = preprocess(example, self.tokenizer)

        return result

    def collate_fn(self, batch):
        input_ids = [item['input_ids'][:-1].tolist() for item in batch]
        labels = [item['input_ids'][1:].tolist() for item in batch]
        loss_mask = [item['mask'][1:].tolist() for item in batch]

        max_length = max([len(x) for x in input_ids])
        if max_length > self.max_seq_length:
            # truncate the sequences if they are longer than max_seq_length
            input_ids = [x[: self.max_seq_length] for x in input_ids]
            labels = [x[: self.max_seq_length] for x in labels]
            loss_mask = [x[: self.max_seq_length] for x in loss_mask]
        # round max_length up to the nearest multiple of 8, or pad to the configured maximum
        if self.pad_to_max_length:
            max_length = self.max_seq_length
        else:
            max_length = min(self.max_seq_length, self._round_to_nearest(max_length, 8))
        assert max_length <= self.max_seq_length

        attention_mask = [self._create_attention_mask(max_length) for _ in batch]
        attention_mask = torch.stack(attention_mask)
        position_ids = [list(range(max_length)) for _ in batch]
        position_ids = torch.LongTensor(position_ids)
        input_ids = torch.LongTensor(
            self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id)
        )
        labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id))
        loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0))

        processed_batch = {
            'tokens': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'position_ids': position_ids,
        }

        return processed_batch
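The [:-1]/[1:] shift in collate_fn pairs each input position with the next token as its label; a toy illustration with assumed values, outside the actual code:

import torch

ids = torch.tensor([7, 8, 9, 10])  # a tokenized conversation
tokens = ids[:-1]  # model input: [7, 8, 9]
labels = ids[1:]   # next-token targets: [8, 9, 10]
# item['mask'][1:] is aligned with labels, so masking a position suppresses the
# loss on predicting the token that follows it.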
@@ -24,6 +24,7 @@
    get_datasets_weights_and_num_samples,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
    MegatronPretrainingBatchSampler,
@@ -234,7 +235,11 @@ def _build_dataset(self, data_cfg, is_train=True):
            num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names)

        for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
-           dataset = GPTSFTDataset(
+           if self.cfg.data.chat:
+               dataset_cls = GPTSFTChatDataset
+           else:
+               dataset_cls = GPTSFTDataset
+           dataset = dataset_cls(
                file_path=file_path,
                tokenizer=self.tokenizer,
                max_seq_length=data_cfg.max_seq_length,
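A hedged sketch of steering this dispatch from the training config; the model.data.chat key comes from the config diff above, while the YAML file name is hypothetical:

from omegaconf import OmegaConf

cfg = OmegaConf.load("megatron_gpt_sft.yaml")  # hypothetical SFT config file
cfg.model.data.chat = True  # build GPTSFTChatDataset instead of GPTSFTDataset for all splits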
84 changes: 84 additions & 0 deletions nemo/collections/nlp/modules/common/chat_css.py
@@ -0,0 +1,84 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CSS = """
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
#chatbot .k { color: #008000; font-weight: bold }
#chatbot .o { color: #666666 }
#chatbot .ch { color: #408080; font-style: italic }
#chatbot .cm { color: #408080; font-style: italic }
#chatbot .cp { color: #BC7A00 }
#chatbot .cpf { color: #408080; font-style: italic }
#chatbot .c1 { color: #408080; font-style: italic }
#chatbot .cs { color: #408080; font-style: italic }
#chatbot .gd { color: #A00000 }
#chatbot .ge { font-style: italic }
#chatbot .gr { color: #FF0000 }
#chatbot .gh { color: #000080; font-weight: bold }
#chatbot .gi { color: #00A000 }
#chatbot .go { color: #888888 }
#chatbot .gp { color: #000080; font-weight: bold }
#chatbot .gs { font-weight: bold }
#chatbot .gu { color: #800080; font-weight: bold }
#chatbot .gt { color: #0044DD }
#chatbot .kc { color: #008000; font-weight: bold }
#chatbot .kd { color: #008000; font-weight: bold }
#chatbot .kn { color: #008000; font-weight: bold }
#chatbot .kp { color: #008000 }
#chatbot .kr { color: #008000; font-weight: bold }
#chatbot .kt { color: #B00040 }
#chatbot .m { color: #666666 }
#chatbot .s { color: #BA2121 }
#chatbot .na { color: #7D9029 }
#chatbot .nb { color: #008000 }
#chatbot .nc { color: #0000FF; font-weight: bold }
#chatbot .no { color: #880000 }
#chatbot .nd { color: #AA22FF }
#chatbot .ni { color: #999999; font-weight: bold }
#chatbot .ne { color: #D2413A; font-weight: bold }
#chatbot .nf { color: #0000FF }
#chatbot .nl { color: #A0A000 }
#chatbot .nn { color: #0000FF; font-weight: bold }
#chatbot .nt { color: #008000; font-weight: bold }
#chatbot .nv { color: #19177C }
#chatbot .ow { color: #AA22FF; font-weight: bold }
#chatbot .w { color: #bbbbbb }
#chatbot .mb { color: #666666 }
#chatbot .mf { color: #666666 }
#chatbot .mh { color: #666666 }
#chatbot .mi { color: #666666 }
#chatbot .mo { color: #666666 }
#chatbot .sa { color: #BA2121 }
#chatbot .sb { color: #BA2121 }
#chatbot .sc { color: #BA2121 }
#chatbot .dl { color: #BA2121 }
#chatbot .sd { color: #BA2121; font-style: italic }
#chatbot .s2 { color: #BA2121 }
#chatbot .se { color: #BB6622; font-weight: bold }
#chatbot .sh { color: #BA2121 }
#chatbot .si { color: #BB6688; font-weight: bold }
#chatbot .sx { color: #008000 }
#chatbot .sr { color: #BB6688 }
#chatbot .s1 { color: #BA2121 }
#chatbot .ss { color: #19177C }
#chatbot .bp { color: #008000 }
#chatbot .fm { color: #0000FF }
#chatbot .vc { color: #19177C }
#chatbot .vg { color: #19177C }
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
"""