diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 38617460a5ad2..39f6b09d55244 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -512,7 +512,8 @@ def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict: conv.append_message(turn['from'], turn['value']) context = conv.get_prompt() if strip_end_for_inference: - context = context.rstrip("\n<extra_id_1>") + "\n" + if context.endswith("\n<extra_id_1>"): + context = context[:-len("\n<extra_id_1>")] + "\n" conversations.append(context) add_extra_token = cfg.get("add_extra_token")