diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py index 886049dd5170..d51a5f973f99 100644 --- a/nemo/collections/multimodal/data/neva/conversation.py +++ b/nemo/collections/multimodal/data/neva/conversation.py @@ -263,6 +263,17 @@ def dict(self): sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n", ) +conv_nv_dpo = Conversation( + system="\n", + roles=("User", "Assistant"), + version="nv_dpo", + messages=(), + offset=0, + sep_style=SeparatorStyle.NVGPT, + sep=DEFAULT_SEPARATOR_TOKEN, + sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n", +) + conv_vicuna_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", @@ -400,6 +411,8 @@ def dict(self): "v1_mmtag": conv_llava_v1_mmtag, "llava_llama_2": conv_llava_llama_2, "nvgpt": conv_nvgpt, + "nv_steerlm": conv_nvgpt, + "nv_dpo": conv_nv_dpo, } diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 90f862869369..15d755a7d59a 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -381,6 +381,8 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict: - The function asserts that each message in a conversation alternates between the defined roles and skips messages not starting with the 'human' role. """ + """System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUser\n{user input}\nAssistant\nquality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n""" + conv = conversation_lib.conv_nvgpt.copy() # Apply prompt templates @@ -462,6 +464,105 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict: return dict(tokens=tokens, labels=labels,) +def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict: + """ + Preprocess a given set of conversational sources using nvgpt conversation template + + This function processes conversations by first ensuring the conversation starts with a 'human' role, then tokenizes the conversations, applies specific token replacements, and finally masks labels for training purposes. + + Parameters: + - sources: A dictionary containing conversational data. Expected format is a dict of conversations, where each conversation is a list of messages, and each message is a dict with 'from' (role) and 'value' (message text). + - tokenizer: A tokenizer from the Hugging Face Transformers library used for tokenizing the conversations. + - cfg: Configuration settings which include 'add_extra_token' (bool) to determine if an extra token should be added to the tokenized output, and 'context_length' for specifying the tokenization context length. + + Returns: + - Dict: A dictionary containing two keys: + - 'tokens': A tensor of tokenized conversation data. + - 'labels': A tensor of labels for the conversation data, used for training models. Labels are masked based on the conversation structure. + + Note: + - The function includes specific token replacements (e.g., DEFAULT_IMAGE_PATCH_TOKEN, , ) and masking techniques for labels. + - It is designed to work with conversational data where messages alternate between a 'human' and a 'gpt' role. + - The function asserts that each message in a conversation alternates between the defined roles and skips messages not starting with the 'human' role. + """ + + """System\n\nUser\n{user input}\nAssistant\n""" + + conv = conversation_lib.conv_nv_dpo.copy() + + # Apply prompt templates + conversations = [] + for source in sources: + conv.messages = [] + conv.system = source.get('system', conv.system) + + strip_end_for_inference = False + for i, turn in enumerate(source['conversations']): + + if i % 2 == 1: + turn['from'] = conv.roles[1] + conv.append_message(turn['from'], turn['value']) + if not turn["value"]: + strip_end_for_inference = ( + True # in inference, current turn is empty, thus end tokens need to striped. + ) + else: + turn['from'] = conv.roles[0] + conv.append_message(turn['from'], turn['value']) + context = conv.get_prompt() + if strip_end_for_inference: + if context.endswith("\n"): + context = context[: -len("\n")] + "\n" + conversations.append(context) + + add_extra_token = cfg.get("add_extra_token") + # Tokenize conversations + tokens = tokenize( + texts=conversations, + tokenizer=tokenizer, + context_length=cfg.get("context_length"), + add_extra_token=add_extra_token, + ) + + labels = tokens.clone().detach() + + # Mask targets + sep = conv.sep + conv.roles[1] + "\n" + for conversation, target in zip(conversations, labels): + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt + + cur_len = 0 + for i, rou in enumerate(re_rounds): + if rou == "": + break + parts = rou.split(sep) + if len(parts) != 2: + break + + instruction_len = len(tokenizer.text_to_ids(parts[0] + sep)) + round_len = len(tokenizer.text_to_ids(rou + conv.sep)) + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + # Check if masking working correctly + # print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())]) + + if add_extra_token: + tokens = tokens[:, :-1].contiguous() + labels = labels[:, 1:].contiguous() + else: + labels = torch.roll(labels, shifts=-1, dims=-1) + labels[:, -1] = IGNORE_INDEX + + return dict(tokens=tokens, labels=labels,) + + def preprocess_plain(sources, tokenizer, cfg,) -> Dict: """ Preprocesses plain text sources (no template) for tokenization and label generation. @@ -604,8 +705,10 @@ def expand2square(pil_img, background_color): images_tensors = torch.tensor([]) sources = copy.deepcopy(sources) - if self.conv_template == "nvgpt": + if self.conv_template in ["nvgpt", "nv_steerlm"]: data_dict = preprocess_nvgpt(sources, self.tokenizer, self.multimodal_cfg,) + elif self.conv_template == "nv_dpo": + data_dict = preprocess_nv_dpo(sources, self.tokenizer, self.multimodal_cfg,) elif self.conv_template == "v1": data_dict = preprocess_v1(sources, self.tokenizer, self.multimodal_cfg,) elif self.conv_template == "llama_2": diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index fd68eef592fd..59452ce96f99 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -329,14 +329,17 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c DEFAULT_IMAGE_TOKEN, preprocess_llama_2, preprocess_multimodal, + preprocess_nv_dpo, preprocess_nvgpt, preprocess_v1, ) list_data_dict = [] - if multimodal_cfg["conv_template"] == "nvgpt": + if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm", "nv_dpo"]: record = { - 'system': 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n', + 'system': '\n' + if multimodal_cfg["conv_template"] == 'nv_dpo' + else 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n', 'conversations': [{'from': 'User', 'value': prompt}, {'from': 'Assistant', 'value': '',},], } @@ -348,7 +351,10 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c sources = preprocess_multimodal( copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents ) # HARDCODED FOR NOW - data_dict = preprocess_nvgpt(sources, tokenizer, multimodal_cfg) + if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm"]: + data_dict = preprocess_nvgpt(sources, tokenizer, multimodal_cfg) + else: + data_dict = preprocess_nv_dpo(sources, tokenizer, multimodal_cfg) elif multimodal_cfg["conv_template"] == "llama_2": record = { diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index c6a8f1e46900..7946b846c7cd 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -181,7 +181,7 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para clean_response = clean_text - if conv_template == "nvgpt": + if conv_template in ["nvgpt", "nv_steerlm"]: labels_str_regexp = re.compile(f"quality:.*\n") last_match_end_position = None for match in re.finditer(labels_str_regexp, clean_response): @@ -189,6 +189,8 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para if last_match_end_position is not None: clean_response = clean_response[last_match_end_position:] clean_response = clean_response.strip("") + elif conv_template == 'nv_dpo': + clean_response = clean_response.split("")[-2][10:] # [10:] for removing "Assistant\n" elif conv_template == "llama_2": clean_response = clean_response.rsplit("[/INST] ", 1)[-1] elif conv_template == "v1":