Skip to content

Commit

Permalink
handle nv_dpo template text generation
Browse files Browse the repository at this point in the history
Signed-off-by: HuiyingLi <[email protected]>
  • Loading branch information
HuiyingLi committed Feb 9, 2024
1 parent b63b190 commit 7f1f85c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
list_data_dict = []
if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm", "nv_dpo"]:
record = {
'system': 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'system': '\n' if multimodal_cfg["conv_template"]=='nv_dpo' else 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'conversations': [{'from': 'User', 'value': prompt}, {'from': 'Assistant', 'value': '',},],
}

Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/nlp/modules/common/text_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,16 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para

clean_response = clean_text

if conv_template == "nvgpt":
if conv_template in ["nvgpt", "nv_steerlm"]:
labels_str_regexp = re.compile(f"<extra_id_2>quality:.*\n")
last_match_end_position = None
for match in re.finditer(labels_str_regexp, clean_response):
last_match_end_position = match.end()
if last_match_end_position is not None:
clean_response = clean_response[last_match_end_position:]
clean_response = clean_response.strip("<extra_id_1>")
elif conv_template == 'nv_dpo':
clean_response = clean_response.split("<extra_id_1>")[-2].strip().split("\n")[-1]
elif conv_template == "llama_2":
clean_response = clean_response.rsplit("[/INST] ", 1)[-1]
elif conv_template == "v1":
Expand Down

0 comments on commit 7f1f85c

Please sign in to comment.