Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix code format
Browse files · Browse the repository at this point in the history
Signed-off-by: Huiying Li <willwin.lee@gmail.com>
HuiyingLi committed Feb 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d621241 commit 45ef3bd
Showing 3 changed files with 7 additions and 10 deletions.
9 changes: 2 additions & 7 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
Original file line number Diff line number Diff line change
@@ -488,7 +488,6 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:

"""<extra_id_0>System\n\n<extra_id_1>User\n{user input}\n<extra_id_1>Assistant\n"""


conv = conversation_lib.conv_nv_dpo.copy()

# Apply prompt templates
@@ -513,7 +512,7 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
context = conv.get_prompt()
if strip_end_for_inference:
if context.endswith("\n<extra_id_1>"):
context = context[:-len("\n<extra_id_1>")] + "\n"
context = context[: -len("\n<extra_id_1>")] + "\n"
conversations.append(context)

add_extra_token = cfg.get("add_extra_token")
@@ -551,10 +550,8 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
cur_len += round_len
target[cur_len:] = IGNORE_INDEX


# Check if masking working correctly
#print(tokenizer.ids_to_text([a[0] for a in filter(lambda x:x[1]!=-1 ,[x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])]))
#print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])
# print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])

if add_extra_token:
tokens = tokens[:, :-1].contiguous()
@@ -563,8 +560,6 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
labels = torch.roll(labels, shifts=-1, dims=-1)
labels[:, -1] = IGNORE_INDEX



return dict(tokens=tokens, labels=labels,)


Original file line number Diff line number Diff line change
@@ -329,15 +329,17 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
DEFAULT_IMAGE_TOKEN,
preprocess_llama_2,
preprocess_multimodal,
preprocess_nvgpt,
preprocess_nv_dpo,
preprocess_nvgpt,
preprocess_v1,
)

list_data_dict = []
if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm", "nv_dpo"]:
record = {
'system': '\n' if multimodal_cfg["conv_template"]=='nv_dpo' else 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'system': '\n'
if multimodal_cfg["conv_template"] == 'nv_dpo'
else 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'conversations': [{'from': 'User', 'value': prompt}, {'from': 'Assistant', 'value': '',},],
}

Original file line number Diff line number Diff line change
@@ -190,7 +190,7 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para
clean_response = clean_response[last_match_end_position:]
clean_response = clean_response.strip("<extra_id_1>")
elif conv_template == 'nv_dpo':
clean_response = clean_response.split("<extra_id_1>")[-2][10:] #[10:] for removing "Assistant\n"
clean_response = clean_response.split("<extra_id_1>")[-2][10:] # [10:] for removing "Assistant\n"
elif conv_template == "llama_2":
clean_response = clean_response.rsplit("[/INST] ", 1)[-1]
elif conv_template == "v1":

0 comments on commit 45ef3bd

Please sign in to comment.