diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py
index 04125c6f750e..c9eb013b64e9 100644
--- a/examples/nlp/language_modeling/megatron_gpt_eval.py
+++ b/examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import asyncio
+import datetime
 import os
 import threading
 from functools import partial
@@ -167,7 +168,11 @@ def remove_padded_prompts(response, nb_paddings):
 def main(cfg) -> None:
 
     # trainer required for restoring model parallel models
-    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer, callbacks=[CustomProgressBar()])
+    trainer = Trainer(
+        strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
+        **cfg.trainer,
+        callbacks=[CustomProgressBar()],
+    )
 
     if cfg.gpt_model_file is not None:
         if (
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
index b0b8eedb5633..27e73996225f 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
@@ -71,6 +71,12 @@ model:
 
   data:
     chat: False # whether to use chatbot data or not
+    chat_prompt_tokens: # special tokens for the chat prompts, a dictionary of {token_type: token}. Note that some tokenizers may combine the characters at the junction between {end_of_turn}{turn_start}, e.g. '</s><extra_id_1>', where the '><' is sometimes merged into a single token. This is not supported; try to avoid it.
      system_turn_start: '<extra_id_0>'
      turn_start: '<extra_id_1>'
      label_start: '<extra_id_2>'
      end_of_turn: "\x0A" # \x0A is '\n'
      end_of_name: "\x0A" # \x0A is '\n'
     train_ds:
       # Example of how to specify paths to multiple datasets
       # file_names:
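Before training, it is worth confirming that the chosen tokenizer keeps {end_of_turn} and {turn_start} separate at turn boundaries, as the yaml comment above warns. A minimal check along these lines (illustrative, not part of this PR; assumes a HuggingFace tokenizer and the default token names above):

# Hypothetical sanity check for the junction caveat described in the yaml comment.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})

end_of_turn, turn_start = '\n', '<extra_id_1>'
joint = tokenizer.encode(end_of_turn + turn_start, add_special_tokens=False)
separate = tokenizer.encode(end_of_turn, add_special_tokens=False) + tokenizer.encode(turn_start, add_special_tokens=False)
assert joint == separate, 'tokenizer merges tokens across the turn junction; choose different special tokens'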
diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
index 9a70671f8073..4c4c001014db 100644
--- a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
+++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -15,11 +15,12 @@
 import os
 import tempfile
 
+import torch.multiprocessing as mp
 from omegaconf.omegaconf import OmegaConf, open_dict
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
-from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector
 
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import get_prompt_template_example
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
 from nemo.collections.nlp.parts.nlp_overrides import (
@@ -36,6 +37,8 @@
 from nemo.utils.exp_manager import exp_manager
 from nemo.utils.model_utils import inject_model_parallel_rank
 
+mp.set_start_method("spawn", force=True)
+
 
 def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
     """
@@ -71,6 +74,13 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
         gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
         gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)
 
+        if cfg.model.data.get('chat', False):
+            # chat model, overwrite the prompt template
+            prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
+            gpt_cfg.data.train_ds.prompt_template = prompt_template
+            gpt_cfg.data.validation_ds.prompt_template = prompt_template
+            gpt_cfg.data.test_ds.prompt_template = prompt_template
+
         sft_cls = MegatronGPTSFTModel
         gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"
 
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
index 801a58394f06..96cc57a300b8 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
@@ -16,19 +16,19 @@
 
 import torch
 
-from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
 from nemo.utils import logging
 
-__all__ = ['GPTSFTChatDataset']
+__all__ = ['GPTSFTChatDataset', 'get_prompt_template_example']
 
-IGNORE_INDEX = -100
-END_SIGNAL = "\n"
-END_NAME_SIGNAL = "\n"
-SYSTEM_TOKEN = "<extra_id_0>System\n"
-TURN_TOKEN = "<extra_id_1>"
+PREFIX_STR = (
+    "\x00"  # the prefix string used in the tokenizer to deal with the empty token that some tokenizers prepend
+)
+
+IGNORE_INDEX = -100
+SYSTEM_TOKEN = "System"
 
 TYPE_INSTRUCTION = {
     'TEXT_TO_VALUE': "",
@@ -36,6 +36,56 @@
 }
 
 
+def _get_header_conversation_type_mask_role(source, special_tokens):
+    END_SIGNAL = special_tokens['end_of_turn']
+    END_NAME_SIGNAL = special_tokens['end_of_name']
+
+    data_type = None
+    if 'type' in source:
+        data_type = source['type']
+    if data_type is not None:
+        assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported"
+    # add end signal and concatenate together
+    conversation = source['system']
+    if data_type is not None:
+        if TYPE_INSTRUCTION[data_type] != '':
+            conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type]
+    mask_role = source.get('mask', 'User')
+    header = f"{special_tokens['system_turn_start']}{SYSTEM_TOKEN}{END_NAME_SIGNAL}{conversation}{END_SIGNAL}"
+    conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type, special_tokens)
+    return header, conversation, data_type, mask_role
+
+
+def get_prompt_template_example(special_tokens):
+    source = {
+        'system': '{system message}',
+        'conversations': [
+            {'from': 'User', 'value': '{turn 1 user message}', 'label': None},
+            {'from': 'Assistant', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'},
+            {'from': 'User', 'value': '{turn 2 user message}', 'label': None},
+            {'from': 'Assistant', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'},
+        ],
+        "mask": "User",
+        "type": "VALUE_TO_TEXT",
+    }
+    _, conversation, _, _ = _get_header_conversation_type_mask_role(source, special_tokens)
+    return conversation
+
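For reference, with the default <extra_id_*> tokens, get_prompt_template_example renders the template below (a sketch reconstructed from the formatting code; the line breaks are the end_of_turn/end_of_name characters, and the same layout is asserted by _test_example_prompt in the tests at the end of this PR):

<extra_id_0>System
{system message}
<extra_id_1>User
{turn 1 user message}
<extra_id_1>Assistant
<extra_id_2>{turn 1 assistant label}
{turn 1 assistant message}
<extra_id_1>User
{turn 2 user message}
<extra_id_1>Assistant
<extra_id_2>{turn 2 assistant label}
{turn 2 assistant message}
<extra_id_1>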
+def identify_start_index_of_subsequence(subsequence, sequence):
+    """ Find the start index of the small tensor in the large tensor, or -1 if it does not occur.
+        e.g. subsequence = [1,3], sequence = [2,3,1,3], returns 2
+             subsequence = [3,2], sequence = [2,3,1,3], returns -1
+    Args:
+        subsequence (tensor): the small tensor to search for
+        sequence (tensor): the large tensor to search in
+    """
+    for i in range(sequence.size(0) - subsequence.size(0) + 1):
+        if torch.equal(sequence[i : i + subsequence.size(0)], subsequence):
+            return i
+    return -1
+
+
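A quick illustration of the helper (not part of the PR), matching the docstring example:

import torch
# returns 2: [1, 3] first appears at index 2 of [2, 3, 1, 3]
assert identify_start_index_of_subsequence(torch.tensor([1, 3]), torch.tensor([2, 3, 1, 3])) == 2
# returns -1: [3, 2] never appears contiguously
assert identify_start_index_of_subsequence(torch.tensor([3, 2]), torch.tensor([2, 3, 1, 3])) == -1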
 def _mask_targets(
     target,
     tokenized_lens,
@@ -45,8 +95,10 @@ def _mask_targets(
     tokenizer,
     mask_role,
     gtype,
-    extra_id_2_token_id,
-    new_line_token_id,
+    name_end_token_ids,
+    special_tokens,
+    label_start_ids,
+    num_turn_start_tokens,
 ):
     """ This function masks the tokens so the loss is computed only on the non-masked role's responses.
    For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes.
@@ -60,68 +112,88 @@ def _mask_targets(
        tokenizer (TokenizerSpec): tokenizer object
        mask_role (str): the speaker id to be masked from loss computation
        gtype (str): either 'TEXT_TO_VALUE' or 'VALUE_TO_TEXT'
-       extra_id_2_token_id (int): <extra_id_2> token id
-       new_line_token_id (int): new line token id
-
+       name_end_token_ids (list): end-of-name token ids
+       special_tokens (dict): special tokens used for the chat prompt. It has the keys: system_turn_start, turn_start, label_start, end_of_turn
+       label_start_ids (list): list of label start token ids
+       num_turn_start_tokens (int): number of tokens in the turn_start string
    """
+    TURN_TOKEN = special_tokens['turn_start']
+    END_NAME_SIGNAL = special_tokens['end_of_name']
+    label_start_ids = torch.tensor(label_start_ids)
+    name_end_token_ids = torch.tensor(name_end_token_ids)
+
     cur_idx = header_len
     tgt_len = target.shape[0]
     for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
         # note, sentence piece will add an extra empty token in front; has to compute the diff
-        id1 = tokenizer.text_to_ids("<extra_id_1>")
-        id2 = tokenizer.text_to_ids("<extra_id_1>" + TURN_TOKEN + speaker + END_NAME_SIGNAL)
-        skip_name_len = len(id2) - len(id1)
-        if extra_id_2_token_id is None:
-            raise ValueError("extra_id_2 is not in the vocabulary")
-        if (s_id == extra_id_2_token_id).any().item():
+        id1 = tokenizer.text_to_ids(PREFIX_STR)
+        id2 = tokenizer.text_to_ids(PREFIX_STR + TURN_TOKEN + speaker + END_NAME_SIGNAL)
+        skip_name_len = len(id2) - len(
+            id1
+        )  # s_ids[:skip_name_len] is the name part of the prompt 'TURN_TOKEN + speaker + END_NAME_SIGNAL'
+        # get the position of the label start string in this turn
+        location = identify_start_index_of_subsequence(label_start_ids, s_id)
+
+        if location >= 0:
+            # this turn contains the label start tokens
             if gtype == 'VALUE_TO_TEXT':
-                # if contains the token <extra_id_2>
-                assert skip_name_len == torch.where((s_id == extra_id_2_token_id))[0].item()
-                # find new line token id 14
-                more_skip_len = torch.where((s_id[skip_name_len:] == new_line_token_id))[0][0].item() + 1
+                # handles the case of conditioning on labels to generate a response:
+                # the next token after the name part of the prompt is the beginning of the label start tokens
+                assert skip_name_len == location
+                # find the first end-of-name token after the label part, which marks the end of the whole label string
+                newline_loc = identify_start_index_of_subsequence(name_end_token_ids, s_id[skip_name_len:])
+                if newline_loc < 0:
+                    # cannot find the end-of-name token, which means the whole turn is just a partial label string; mask the whole turn
+                    target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
+                    continue
+                # skip the label part and the end-of-name token
+                more_skip_len = newline_loc + len(name_end_token_ids)
+                # skip the name part and the label part
                 skip_name_len += more_skip_len
             elif gtype == 'TEXT_TO_VALUE':
-                skip_name_len = torch.where((s_id == extra_id_2_token_id))[0].item() + 1
+                # handles the case of conditioning on the response to generate labels:
+                # skip the name part, the response, and the label start tokens; the remainder is the label string without the label start, e.g. 'quality:9,toxicity:8...'
+                skip_name_len = location + len(label_start_ids)
         if cur_idx >= tgt_len:
             break
         elif cur_idx + tokenized_len < tgt_len:
-            # Check whether the mask is applied to the correct position, the first token is turn token: <extra_id_1>
-            # s_id[2:] skips the artifact empty token and the turn token: <extra_id_1>
-            # target[cur_idx + 1:cur_idx + tokenized_len] skips the turn token
+            # Check whether the mask is applied to the correct position; the first tokens are the turn start tokens
             if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]):
                 logging.warning("a sentence mismatches the corresponding piece " "in the conversation")
         if i == 0 and (gtype == 'VALUE_TO_TEXT' or gtype is None):
-            # mask the first turn completely to provide at least one turn as context
+            # mask the first turn completely to provide at least one turn as context for the rest
             target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
         elif speaker == mask_role and i == 1 and gtype == 'TEXT_TO_VALUE':
-            # leave the first human tag unmasked
-            target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
+            # leave the first turn start tag unmasked; it serves as the end-of-turn signal
+            target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX
         elif speaker == mask_role and (i > 1):
-            # leave the first human tag unmasked
-            target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
+            # leave the first turn start tag unmasked; it serves as the end-of-turn signal
+            target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX
         elif speaker == mask_role and (i <= 1):
             # mask out everything in the second turn
             target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
         else:
-            # mask up to the name end, need to remove one as skip name has an extra artifact empty token
+            # mask up to the name and label parts for VALUE_TO_TEXT; the name part, response and label start tokens for TEXT_TO_VALUE; or just the name part if gtype is None
             target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX
         cur_idx += tokenized_len
 
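To make the masking concrete: for a VALUE_TO_TEXT record with mask_role='User' and the default tokens, only the assistant message bodies plus the trailing end_of_turn and next turn_start survive the mask, which is exactly what the unit tests at the end of this PR assert. Illustrative only:

# Unmasked text for a two-turn conversation (mirrors the tests' expected_text):
expected_unmasked = (
    '{turn 1 assistant message}' + '\n' + '<extra_id_1>'
    + '{turn 2 assistant message}' + '\n' + '<extra_id_1>'
)
# Everything else, i.e. the system header, the user turns, the speaker names and
# the label headers, is set to IGNORE_INDEX (-100) and excluded from the loss.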
-def cannonical_form_formater(cannoical_form):
-    return f'<extra_id_2>{cannoical_form}\n'
-
-
-def response_value_formater(label):
+def response_value_formater(label, label_start, end_signal):
     if isinstance(label, str):
-        return '<extra_id_2>' + label + '\n'
+        return label_start + label + end_signal
     elif label is None:
         return ''
     else:
         raise ValueError(f'Unknown label type {type(label)}, only str type is supported')
 
 
-def _add_speaker_and_signal(header, source, mask_role, gtype):
+def _add_speaker_and_signal(header, source, mask_role, gtype, special_tokens):
+    TURN_TOKEN = special_tokens['turn_start']
+    END_SIGNAL = special_tokens['end_of_turn']
+    LABEL_START = special_tokens['label_start']
+    END_NAME_SIGNAL = special_tokens['end_of_name']
+
     """Add speaker and start/end signal on each round."""
     BEGIN_SIGNAL = ""
     conversation = header
@@ -138,7 +210,11 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
                 + role_token
                 + sentence_from
                 + END_NAME_SIGNAL
-                + (response_value_formater(sentence['label']) if 'label' in sentence else '')
+                + (
+                    response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL)
+                    if 'label' in sentence
+                    else ''
+                )
                 + sentence["value"]
                 + END_SIGNAL
             )
@@ -150,7 +226,11 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
                 + END_NAME_SIGNAL
                 + sentence["value"]
                 + END_SIGNAL
-                + (response_value_formater(sentence['label']) if 'label' in sentence else '')
+                + (
+                    response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL)
+                    if 'label' in sentence
+                    else ''
+                )
             )
         else:
             raise ValueError(
@@ -163,7 +243,14 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
     return conversation
 
 
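The PREFIX_STR ("\x00") trick used throughout this file, including in preprocess below, sidesteps tokenizers that insert an artifact empty token at the start of a sequence: tokenize PREFIX_STR + text, tokenize PREFIX_STR alone, and keep only the tail. A minimal sketch of the same idea (illustrative; tokenizer is any NeMo TokenizerSpec):

PREFIX_STR = "\x00"

def text_to_ids_without_artifact(tokenizer, text):
    # Tokenize with a throwaway prefix, then drop the prefix's tokens; whatever
    # artifact token the tokenizer prepends is absorbed by the prefix.
    with_prefix = tokenizer.text_to_ids(PREFIX_STR + text)
    prefix_only = tokenizer.text_to_ids(PREFIX_STR)
    return with_prefix[len(prefix_only):]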
-def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int, new_line_token_id: int):
+def preprocess(
+    source: dict,
+    tokenizer: TokenizerSpec,
+    name_end_token_ids: list,
+    label_start_ids: list,
+    special_tokens: dict,
+    num_turn_start_tokens: int,
+):
     """
     Given a conversation list. This transform:
     1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
@@ -171,36 +258,23 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int,
     3. Tokenize the concatenated conversation;
     4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
     """
-    data_type = None
-    if 'type' in source:
-        data_type = source['type']
-        assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported"
-    # add end signal and concatenate together
-    conversation = source['system']
-    if data_type is not None:
-        if TYPE_INSTRUCTION[data_type] != '':
-            conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type]
-    mask_role = source.get('mask', 'User')
-    header = f"{SYSTEM_TOKEN}{conversation}"
-    conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type)
+    header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role(source, special_tokens)
     # tokenize conversations
     input_ids = tokenizer.text_to_ids(conversation)
     target = copy.deepcopy(input_ids)
-    header_len = len(tokenizer.text_to_ids(header))
+    header_tokens = tokenizer.text_to_ids(header)
+    header_len = len(header_tokens)
 
     ids = []
     tokenized_lens = []
+    assert torch.equal(torch.tensor(target[:header_len]), torch.tensor(header_tokens))
     for s in source['conversations']:
-        if isinstance(tokenizer, SentencePieceTokenizer):
-            tokenized_sentence = tokenizer.text_to_ids(s["value"])
-            ids.append(torch.tensor(tokenized_sentence)[1:])
-            # remove one token as it adds an empty token in front
-            tokenized_lens.append(len(tokenized_sentence) - 1)
-        else:
-            tokenized_sentence = tokenizer.text_to_ids(s["value"])
-            ids.append(torch.tensor(tokenized_sentence))
-            # remove one token as it adds an empty token in front
-            tokenized_lens.append(len(tokenized_sentence))
+        # hack to remove the extra empty token in front
+        id1 = tokenizer.text_to_ids(PREFIX_STR + s["value"])
+        id2 = tokenizer.text_to_ids(PREFIX_STR)
+        tokenized_sentence = id1[len(id2) :]
+        ids.append(torch.tensor(tokenized_sentence))
+        tokenized_lens.append(len(tokenized_sentence))
     speakers = [sentence["from"] for sentence in source['conversations']]
     assert mask_role in speakers, "mask role not in the conversation"
     target = torch.LongTensor(target)
@@ -216,8 +290,10 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int,
         tokenizer,
         mask_role,
         data_type,
-        extra_id_2_token_id,
-        new_line_token_id,
+        name_end_token_ids,
+        special_tokens,
+        label_start_ids,
+        num_turn_start_tokens,
     )
     mask = (target != IGNORE_INDEX).bool()
     assert mask.sum().item() != 0, "mask is empty"
@@ -228,14 +304,6 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int,
     return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids)
 
 
-def _check_token_in_vocab(tokenizer, token):
-    ids = tokenizer.text_to_ids(token)
-    if isinstance(tokenizer, SentencePieceTokenizer):
-        return len(ids) == 2
-    else:
-        return len(ids) == 1
-
-
 class GPTSFTChatDataset(GPTSFTDataset):
     def _maybe_validate_prompt_template(self):
         pass
@@ -243,22 +311,20 @@ def _maybe_validate_prompt_template(self):
     def _build_samples_mapping(self):
         super()._build_samples_mapping()
         assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
-        assert _check_token_in_vocab(
-            self.tokenizer, '<extra_id_0>'
-        ), "<extra_id_0> not in the tokenizer vocab. not supported"
-        assert _check_token_in_vocab(
-            self.tokenizer, '<extra_id_1>'
-        ), "<extra_id_1> not in the tokenizer vocab. not supported"
-        # calcuilate <extra_id_2> id value
-        if _check_token_in_vocab(self.tokenizer, '<extra_id_2>'):
-            ids_1 = self.tokenizer.text_to_ids('<extra_id_1><extra_id_2>')
-            ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
-            self.extra_id_2_token_id = ids_1[len(ids_2) :][0]
-        else:
-            self.extra_id_2_token_id = None
-        ids_1 = self.tokenizer.text_to_ids('<extra_id_1>\n')
-        ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
-        self.new_line_token_id = ids_1[len(ids_2) :][0]
+        LABEL_START = self.special_tokens['label_start']
+        END_NAME_SIGNAL = self.special_tokens['end_of_name']
+
+        id1 = self.tokenizer.text_to_ids(PREFIX_STR)
+        id2 = self.tokenizer.text_to_ids(PREFIX_STR + LABEL_START)
+        self.label_start_tokens = id2[len(id1) :]
+
+        id1 = self.tokenizer.text_to_ids(PREFIX_STR + END_NAME_SIGNAL)
+        id2 = self.tokenizer.text_to_ids(PREFIX_STR)
+        self.name_end_token_ids = id1[len(id2) :]
+
+        id1 = self.tokenizer.text_to_ids(PREFIX_STR + self.special_tokens['turn_start'])
+        id2 = self.tokenizer.text_to_ids(PREFIX_STR)
+        self.num_turn_start_tokens = len(id1) - len(id2)
 
     def _process_example(self, example):
         """
@@ -266,7 +332,14 @@ def _process_example(self, example):
         Truncation is carried out when needed, but it is performed only on the prompt side.
         BOS, EOS, and SEP, are added if specified.
         """
-        result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id)
+        result = preprocess(
+            example,
+            self.tokenizer,
+            self.name_end_token_ids,
+            self.label_start_tokens,
+            self.special_tokens,
+            self.num_turn_start_tokens,
+        )
 
         # store metadata in dataset, in case user may have keys required in the prediction json files
         metadata = {k: v for k, v in example.items() if k not in ['conversations']}
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 101201ef7536..9c6e50f5e43f 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -13,10 +13,14 @@
 # limitations under the License.
 
 import re
-from typing import List, Optional
+from typing import List, Mapping, Optional
 
+import datasets
 import numpy as np
 import torch
+
+# hack to avoid the "not enough disk space" error in some Slurm clusters
+datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory='.': True
 from datasets import load_dataset
 
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -52,6 +56,7 @@ def __init__(
         memmap_workers: Optional[int] = None,
         hf_dataset: bool = False,
         truncation_method: str = 'right',
+        special_tokens: Optional[Mapping[str, str]] = None,  # special tokens, a dictionary of {token_type: token}
     ):
         """
         file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
@@ -73,6 +78,7 @@ def __init__(
         prompt_template: Prompt template to inject via an fstring. Formatted like Q: {context_key}\n\nA: {label_key}
         hf_dataset: Whether to load the json file with the HuggingFace dataset. Otherwise, will load the jsonl file with the JSONLMemMapDataset.
         truncation_method: Truncation from which position. Options: ['left', 'right']
+        special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', 'end_of_name': '\n'}
         """
         self.tokenizer = tokenizer
         self.file_path = file_path
@@ -93,6 +99,16 @@ def __init__(
         self.virtual_tokens = virtual_tokens
         self.tokens_to_generate = tokens_to_generate
         self.truncation_method = truncation_method
+        if special_tokens is None:
+            self.special_tokens = {
+                "system_turn_start": "<extra_id_0>",
+                "turn_start": "<extra_id_1>",
+                "label_start": "<extra_id_2>",
+                "end_of_turn": "\n",
+                "end_of_name": "\n",
+            }
+        else:
+            self.special_tokens = special_tokens
 
         if hf_dataset:
             self.indexed_dataset = load_dataset(
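To override the defaults, pass the dictionary straight through when constructing the dataset, mirroring how the tests below do it (a sketch; the values are illustrative ChatML-style markers and the file path is hypothetical):

special_tokens = {
    "system_turn_start": "<|im_start|>",
    "turn_start": "<|im_start|>",
    "label_start": "<|label|>",
    "end_of_turn": "<|im_end|>\n",
    "end_of_name": "\n",
}
# positional args as in the tests: file path, tokenizer, max_seq_length, min_seq_length
dataset = GPTSFTChatDataset(
    '/path/to/chat_data.jsonl', tokenizer, 4096, 1, hf_dataset=True, special_tokens=special_tokens
)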
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 8f2a108fcfc0..84a1f52d1391 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -33,7 +33,6 @@
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
 from nemo.collections.nlp.modules.common.text_generation_utils import generate, get_computeprob_response
-
 from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.utils import AppState, logging
@@ -296,9 +295,11 @@ def _build_dataset(self, data_cfg, is_train=True):
                 truncation_method=data_cfg.get(
                     'truncation_method', 'right'
                 ),  # used to choose truncation method. Options: ['random', 'left', 'right']
+                special_tokens=self.cfg.data.get(
+                    'chat_prompt_tokens', None
+                ),  # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', 'end_of_name': '\n'}
             )
             datasets.append(dataset)
-
         if is_train:
             dataset = BlendableDataset(
                 datasets=datasets, weights=data_cfg.concat_sampling_probabilities, size=num_train_samples_after_blend
diff --git a/scripts/nlp_language_modeling/sft/data_clean.py b/scripts/nlp_language_modeling/sft/data_clean.py
index 8c67aa2e3bcd..362f7edeeb3b 100644
--- a/scripts/nlp_language_modeling/sft/data_clean.py
+++ b/scripts/nlp_language_modeling/sft/data_clean.py
@@ -41,7 +41,7 @@ def data_clean(
     )
     if library == 'huggingface':
         tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
-    d = GPTSFTChatDataset(dataset_file, tokenizer, seq_len, 1)
+    d = GPTSFTChatDataset(dataset_file, tokenizer, seq_len, 1, hf_dataset=True)
     total_records = len(d)
     removed_ids = set()
     for i in range(total_records):
diff --git a/scripts/nlp_language_modeling/sft/preprocessing.py b/scripts/nlp_language_modeling/sft/preprocessing.py
index 7a08e055543d..175187a279bc 100644
--- a/scripts/nlp_language_modeling/sft/preprocessing.py
+++ b/scripts/nlp_language_modeling/sft/preprocessing.py
@@ -80,11 +80,12 @@ def parse_conversations(tree_obj):
                 raise ValueError(f'unknown role {prompt_obj["role"]}')
             turn = {'value': prompt_obj['text'], 'from': role}
             if 'labels' in prompt_obj:
-                turn['human_labels'] = prompt_obj['labels']
-                for key in turn['human_labels']:
-                    value_set = label_values.get(key, set())
-                    value_set.add(turn['human_labels'][key]['value'])
-                    label_values[key] = value_set
+                # remove human labels
+                # turn['human_labels'] = prompt_obj['labels']
+                # for key in turn['human_labels']:
+                #     value_set = label_values.get(key, set())
+                #     value_set.add(turn['human_labels'][key]['value'])
+                #     label_values[key] = value_set
                 turn['label'] = encode_labels(prompt_obj['labels'])
             if 'lang' in prompt_obj:
                 turn['lang'] = prompt_obj['lang'].split('-')[0]
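For reference, each line of the chat JSONL consumed by GPTSFTChatDataset has the shape produced by create_data_points in the tests below (values are illustrative; the label string format follows the 'quality:9,toxicity:8...' convention mentioned in _mask_targets):

record = {
    "system": "a chat",
    "type": "VALUE_TO_TEXT",  # or 'TEXT_TO_VALUE'
    "mask": "User",           # the role whose turns are excluded from the loss
    "conversations": [
        {"from": "User", "value": "...", "label": None},
        {"from": "Assistant", "value": "...", "label": "quality:9,toxicity:0"},
    ],
}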
"end_of_name": "\n", + } + cls.suffix = cls.special_tokens['end_of_turn'] + cls.special_tokens['turn_start'] + cls.label_suffix = cls.special_tokens['end_of_name'] + cls.special_tokens['turn_start'] - @pytest.mark.unit - def test_43B_tokenizer_mask_user(self): + def _mask_user_test(self, tokenizer, ids_to_text): random.seed(5) temp_file = '/tmp/test_file.jsonl' turn_num = 5 records = 5 try: data_points = create_data_points(True, turn_num, records, temp_file, t2v=False) - tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B) - d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True) + d = GPTSFTChatDataset( + temp_file, + tokenizer, + 4096, + 1, + index_mapping_dir='/tmp/', + hf_dataset=True, + special_tokens=self.special_tokens, + ) for i in range(len(d)): result = d[i] input_ids = result['input_ids'] mask = result['mask'] - text = tokenizer.ids_to_text(input_ids[mask].tolist()) + text = ids_to_text(input_ids[mask].tolist()) expected_text = '' for j in range(1, turn_num, 2): - expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '' + expected_text += data_points[i]['conversations'][j]['value'] + self.suffix assert text == expected_text finally: os.remove(temp_file) - @pytest.mark.unit - def test_43B_tokenizer_mask_assistant(self): + def _mask_assistant_test(self, tokenizer, ids_to_text): random.seed(3) temp_file = '/tmp/test_file.jsonl' turn_num = 5 records = 5 try: data_points = create_data_points(False, turn_num, records, temp_file, t2v=False) - tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B) - d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True) + d = GPTSFTChatDataset( + temp_file, + tokenizer, + 4096, + 1, + index_mapping_dir='/tmp/', + hf_dataset=True, + special_tokens=self.special_tokens, + ) for i in range(len(d)): result = d[i] input_ids = result['input_ids'] mask = result['mask'] - text = tokenizer.ids_to_text(input_ids[mask].tolist()) + text = ids_to_text(input_ids[mask].tolist()) expected_text = '' for j in range(2, turn_num, 2): - expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '' + expected_text += data_points[i]['conversations'][j]['value'] + self.suffix assert text == expected_text finally: os.remove(temp_file) - @pytest.mark.unit - def test_43B_tokenizer_mask_user_t2v(self): + def _mask_user_t2v_test(self, tokenizer, ids_to_text): random.seed(5) temp_file = '/tmp/test_file.jsonl' turn_num = 5 records = 5 try: data_points = create_data_points(True, turn_num, records, temp_file, t2v=True) - tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B) - d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True) + d = GPTSFTChatDataset( + temp_file, + tokenizer, + 4096, + 1, + index_mapping_dir='/tmp/', + hf_dataset=True, + special_tokens=self.special_tokens, + ) for i in range(len(d)): result = d[i] input_ids = result['input_ids'] mask = result['mask'] - text = tokenizer.ids_to_text(input_ids[mask].tolist()) + text = ids_to_text(input_ids[mask].tolist()) expected_text = '' for j in range(1, turn_num, 2): - expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '' + expected_text += data_points[i]['conversations'][j]['label'] + self.label_suffix assert text == expected_text finally: os.remove(temp_file) - @pytest.mark.unit - def test_43B_tokenizer_mask_assistant_t2v(self): + def 
-    @pytest.mark.unit
-    def test_43B_tokenizer_mask_assistant_t2v(self):
+    def _mask_assistant_t2v_test(self, tokenizer, ids_to_text):
         random.seed(5)
         temp_file = '/tmp/test_file.jsonl'
         turn_num = 5
         records = 5
         try:
             data_points = create_data_points(False, turn_num, records, temp_file, t2v=True)
-            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
+            d = GPTSFTChatDataset(
+                temp_file,
+                tokenizer,
+                4096,
+                1,
+                index_mapping_dir='/tmp/',
+                hf_dataset=True,
+                special_tokens=self.special_tokens,
+            )
             for i in range(len(d)):
                 result = d[i]
                 input_ids = result['input_ids']
                 mask = result['mask']
-                text = tokenizer.ids_to_text(input_ids[mask].tolist())
+                text = ids_to_text(input_ids[mask].tolist())
                 expected_text = ''
                 for j in range(0, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
+                    expected_text += data_points[i]['conversations'][j]['label'] + self.label_suffix
                 assert text == expected_text
         finally:
             os.remove(temp_file)
 
-    @pytest.mark.unit
-    def test_mpt_tokenizer_mask_user(self):
+    def _mask_user_nolabel_test(self, tokenizer, ids_to_text):
         random.seed(5)
         temp_file = '/tmp/test_file.jsonl'
         turn_num = 5
         records = 5
         try:
-            data_points = create_data_points(True, turn_num, records, temp_file, t2v=False)
-            tokenizer = get_nmt_tokenizer(
-                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
-            )
-            tokenizer.add_special_tokens(
-                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            data_points = create_data_points(True, turn_num, records, temp_file, t2v=False, label=False)
+            d = GPTSFTChatDataset(
+                temp_file,
+                tokenizer,
+                4096,
+                1,
+                index_mapping_dir='/tmp/',
+                hf_dataset=True,
+                special_tokens=self.special_tokens,
             )
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
             for i in range(len(d)):
                 result = d[i]
                 input_ids = result['input_ids']
                 mask = result['mask']
-                text = ids_to_text(tokenizer, input_ids[mask].tolist())
+                text = ids_to_text(input_ids[mask].tolist())
                 expected_text = ''
                 for j in range(1, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                    expected_text += data_points[i]['conversations'][j]['value'] + self.suffix
                 assert text == expected_text
         finally:
             os.remove(temp_file)
 
-    @pytest.mark.unit
-    def test_mpt_tokenizer_mask_assistant(self):
+    def _mask_assistant_nolabel_test(self, tokenizer, ids_to_text):
         random.seed(3)
         temp_file = '/tmp/test_file.jsonl'
         turn_num = 5
         records = 5
         try:
-            data_points = create_data_points(False, turn_num, records, temp_file, t2v=False)
-            tokenizer = get_nmt_tokenizer(
-                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
-            )
-            tokenizer.add_special_tokens(
-                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            data_points = create_data_points(False, turn_num, records, temp_file, t2v=False, label=False)
+            d = GPTSFTChatDataset(
+                temp_file,
+                tokenizer,
+                4096,
+                1,
+                index_mapping_dir='/tmp/',
+                hf_dataset=True,
+                special_tokens=self.special_tokens,
             )
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
             for i in range(len(d)):
                 result = d[i]
                 input_ids = result['input_ids']
                 mask = result['mask']
-                text = ids_to_text(tokenizer, input_ids[mask].tolist())
+                text = ids_to_text(input_ids[mask].tolist())
                 expected_text = ''
                 for j in range(2, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
+                    expected_text += data_points[i]['conversations'][j]['value'] + self.suffix
                 assert text == expected_text
         finally:
             os.remove(temp_file)
 
-    @pytest.mark.unit
-    def test_mpt_tokenizer_mask_user_t2v(self):
+    def _test_example_prompt(self, tokenizer):
         random.seed(5)
-        temp_file = '/tmp/test_file.jsonl'
-        turn_num = 5
-        records = 5
-        try:
-            data_points = create_data_points(True, turn_num, records, temp_file, t2v=True)
-            tokenizer = get_nmt_tokenizer(
-                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        conv = get_prompt_template_example(self.special_tokens)
+        expected = (
+            self.special_tokens['system_turn_start']
+            + 'System'
+            + self.special_tokens['end_of_name']
+            + '{system message}'
+            + self.special_tokens['end_of_turn']
+        )
+        for turn in range(2):
+            expected += (
+                self.special_tokens['turn_start']
+                + 'User'
+                + self.special_tokens['end_of_name']
+                + f'{{turn {turn + 1} user message}}'
+                + self.special_tokens['end_of_turn']
             )
-            tokenizer.add_special_tokens(
-                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
+            expected += self.special_tokens['turn_start'] + 'Assistant' + self.special_tokens['end_of_name']
+            expected += (
+                self.special_tokens['label_start']
+                + f'{{turn {turn + 1} assistant label}}'
+                + self.special_tokens['end_of_name']
             )
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
-            for i in range(len(d)):
-                result = d[i]
-                input_ids = result['input_ids']
-                mask = result['mask']
-                text = ids_to_text(tokenizer, input_ids[mask].tolist())
-                expected_text = ''
-                for j in range(1, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
-                assert text == expected_text
-        finally:
-            os.remove(temp_file)
+            expected += f'{{turn {turn + 1} assistant message}}' + self.special_tokens['end_of_turn']
+        expected += self.special_tokens['turn_start']
+        assert conv == expected
 
     @pytest.mark.unit
-    def test_mpt_tokenizer_mask_assistant_t2v(self):
-        random.seed(5)
-        temp_file = '/tmp/test_file.jsonl'
-        turn_num = 5
-        records = 5
-        try:
-            data_points = create_data_points(False, turn_num, records, temp_file, t2v=True)
-            tokenizer = get_nmt_tokenizer(
-                library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
-            )
-            tokenizer.add_special_tokens(
-                {'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']}
-            )
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
-            for i in range(len(d)):
-                result = d[i]
-                input_ids = result['input_ids']
-                mask = result['mask']
-                text = ids_to_text(tokenizer, input_ids[mask].tolist())
-                expected_text = ''
-                for j in range(0, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['label'] + '\n' + '<extra_id_1>'
-                assert text == expected_text
-        finally:
-            os.remove(temp_file)
+    def test_43B_example_prompt(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._test_example_prompt(tokenizer)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_user(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._mask_user_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_assistant(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._mask_assistant_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_user_t2v(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._mask_user_t2v_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_43B_tokenizer_mask_assistant_t2v(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._mask_assistant_t2v_test(tokenizer, tokenizer.ids_to_text)
 
     @pytest.mark.unit
     def test_43B_tokenizer_mask_user_nolabel(self):
-        random.seed(5)
-        temp_file = '/tmp/test_file.jsonl'
-        turn_num = 5
-        records = 5
-        try:
-            data_points = create_data_points(True, turn_num, records, temp_file, t2v=False, label=False)
-            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
-            for i in range(len(d)):
-                result = d[i]
-                input_ids = result['input_ids']
-                mask = result['mask']
-                text = tokenizer.ids_to_text(input_ids[mask].tolist())
-                expected_text = ''
-                for j in range(1, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
-                assert text == expected_text
-        finally:
-            os.remove(temp_file)
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._mask_user_nolabel_test(tokenizer, tokenizer.ids_to_text)
 
     @pytest.mark.unit
     def test_43B_tokenizer_mask_assistant_nolabel(self):
-        random.seed(3)
-        temp_file = '/tmp/test_file.jsonl'
-        turn_num = 5
-        records = 5
-        try:
-            data_points = create_data_points(False, turn_num, records, temp_file, t2v=False, label=False)
-            tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
-            d = GPTSFTChatDataset(temp_file, tokenizer, 4096, 1, index_mapping_dir='/tmp/', hf_dataset=True)
-            for i in range(len(d)):
-                result = d[i]
-                input_ids = result['input_ids']
-                mask = result['mask']
-                text = tokenizer.ids_to_text(input_ids[mask].tolist())
-                expected_text = ''
-                for j in range(2, turn_num, 2):
-                    expected_text += data_points[i]['conversations'][j]['value'] + '\n' + '<extra_id_1>'
-                assert text == expected_text
-        finally:
-            os.remove(temp_file)
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
+        self._mask_assistant_nolabel_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_user(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+        self._mask_user_test(tokenizer, partial(ids_to_text, tokenizer))
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_assistant(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+        self._mask_assistant_test(tokenizer, partial(ids_to_text, tokenizer))
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_user_t2v(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+        self._mask_user_t2v_test(tokenizer, partial(ids_to_text, tokenizer))
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_assistant_t2v(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE,
use_fast=True
+        )
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+        self._mask_assistant_t2v_test(tokenizer, partial(ids_to_text, tokenizer))
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_user_nolabel(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+        self._mask_user_nolabel_test(tokenizer, partial(ids_to_text, tokenizer))
+
+    @pytest.mark.unit
+    def test_mpt_tokenizer_mask_assistant_nolabel(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']})
+        self._mask_assistant_nolabel_test(tokenizer, partial(ids_to_text, tokenizer))
+
+    @pytest.mark.unit
+    def test_llama2_tokenizer_mask_user(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_Llama2)
+        self._mask_user_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_llama2_tokenizer_mask_assistant(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_Llama2)
+        self._mask_assistant_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_llama2_tokenizer_mask_user_t2v(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_Llama2)
+        self._mask_user_t2v_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_llama2_tokenizer_mask_assistant_t2v(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_Llama2)
+        self._mask_assistant_t2v_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_llama2_tokenizer_mask_user_nolabel(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_Llama2)
+        self._mask_user_nolabel_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_llama2_tokenizer_mask_assistant_nolabel(self):
+        tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_Llama2)
+        self._mask_assistant_nolabel_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_normal_mpt_tokenizer_mask_user(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        self._mask_user_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_normal_mpt_tokenizer_mask_assistant(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        self._mask_assistant_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_normal_mpt_tokenizer_mask_user_t2v(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        self._mask_user_t2v_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_normal_mpt_tokenizer_mask_assistant_t2v(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        self._mask_assistant_t2v_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_normal_mpt_tokenizer_mask_user_nolabel(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        self._mask_user_nolabel_test(tokenizer, tokenizer.ids_to_text)
+
+    @pytest.mark.unit
+    def test_normal_mpt_tokenizer_mask_assistant_nolabel(self):
+        tokenizer = get_nmt_tokenizer(
+            library='huggingface', model_name='gpt2', merges_file=MERGE_FILE, vocab_file=VOCAB_FILE, use_fast=True
+        )
+        self._mask_assistant_nolabel_test(tokenizer, tokenizer.ids_to_text)
+
+
+class TestDifferentGPTSFTChatDataset(TestGPTSFTChatDataset):
+    @classmethod
+    def setup_class(cls):
+        cls.special_tokens = {
+            "system_turn_start": "<|im_start|>",
+            "turn_start": "<|im_start|>",
+            "label_start": "<|label|>",
+            "end_of_turn": "<|im_end|>\n",
+            "end_of_name": "\n",
+        }
+        cls.suffix = cls.special_tokens['end_of_turn'] + cls.special_tokens['turn_start']
+        cls.label_suffix = cls.special_tokens['end_of_name'] + cls.special_tokens['turn_start']
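Because TestDifferentGPTSFTChatDataset only overrides setup_class, pytest re-collects every inherited test and runs the whole suite a second time with the ChatML-style tokens. An illustrative invocation (assuming the unit marker is registered and the tokenizer test data under /home/TestData is available):

# pytest tests/collections/nlp/test_chat_sft_dataset.py -m unit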