From b63b1900b74e03347378549e9fec99101a447a1d Mon Sep 17 00:00:00 2001
From: HuiyingLi
Date: Wed, 7 Feb 2024 17:35:30 -0800
Subject: [PATCH] add nv_dpo conversation to accommodate empty system message

Signed-off-by: HuiyingLi
---
 .../multimodal/data/neva/conversation.py | 15 +++++++++++++--
 .../multimodal/data/neva/neva_dataset.py  |  3 +--
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py
index 744329b47ed41..d51a5f973f99e 100644
--- a/nemo/collections/multimodal/data/neva/conversation.py
+++ b/nemo/collections/multimodal/data/neva/conversation.py
@@ -263,6 +263,17 @@ def dict(self):
     sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n",
 )
 
+conv_nv_dpo = Conversation(
+    system="\n",
+    roles=("User", "Assistant"),
+    version="nv_dpo",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.NVGPT,
+    sep=DEFAULT_SEPARATOR_TOKEN,
+    sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n",
+)
+
 conv_vicuna_v0 = Conversation(
     system="A chat between a curious human and an artificial intelligence assistant. "
     "The assistant gives helpful, detailed, and polite answers to the human's questions.",
@@ -400,8 +411,8 @@ def dict(self):
     "v1_mmtag": conv_llava_v1_mmtag,
     "llava_llama_2": conv_llava_llama_2,
     "nvgpt": conv_nvgpt,
-    "nv_dpo": conv_nvgpt,
-    "nv_steerlm": conv_nvgpt
+    "nv_steerlm": conv_nvgpt,
+    "nv_dpo": conv_nv_dpo,
 }
diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py
index 5434392ce4820..ea778cd591452 100644
--- a/nemo/collections/multimodal/data/neva/neva_dataset.py
+++ b/nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -487,7 +487,7 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
     """System\n\nUser\n{user input}\nAssistant\n"""
 
-    conv = conversation_lib.conv_nvgpt.copy()
+    conv = conversation_lib.conv_nv_dpo.copy()
 
     # Apply prompt templates
     conversations = []
@@ -543,7 +543,6 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
         instruction_len = len(tokenizer.text_to_ids(parts[0] + sep))
         round_len = len(tokenizer.text_to_ids(rou + conv.sep))
-        import pdb; pdb.set_trace()
 
         target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
         cur_len += round_len
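
The layout this template targets is the one quoted in the preprocess_nv_dpo docstring: the system turn is kept but left empty (system="\n"). The sketch below is a minimal standalone illustration of that layout, not NeMo code; render_nv_dpo is a hypothetical helper, and the placeholder token strings stand in for the real DEFAULT_SYSTEM_TOKEN / DEFAULT_SEPARATOR_TOKEN values used in the sep and sep2 fields of the diff.

# Minimal standalone sketch (hypothetical helper, not part of NeMo) of the
# prompt layout conv_nv_dpo is meant to produce: sep2 ("...System\n") followed
# by the empty system string "\n", then one "<sep>Role\n..." block per turn.
# The placeholder tokens below are assumptions standing in for the values of
# DEFAULT_SYSTEM_TOKEN and DEFAULT_SEPARATOR_TOKEN defined in conversation.py.

def render_nv_dpo(turns, system="\n", system_token="<SYSTEM_TOKEN>", sep="<SEP>"):
    prompt = f"{system_token}System\n{system}"  # empty system message: just a newline
    for role, message in turns:                 # roles are ("User", "Assistant")
        prompt += f"{sep}{role}\n{message}"
        if message:                             # trailing "\n" only after a filled turn,
            prompt += "\n"                      # so generation starts right after "Assistant\n"
    return prompt

if __name__ == "__main__":
    # Mirrors the docstring pattern: System\n\nUser\n{user input}\nAssistant\n
    print(repr(render_nv_dpo([("User", "What is NeMo?"), ("Assistant", "")])))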