diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py
index a66403aefd..2b77112103 100644
--- a/nemo_rl/data/llm_message_utils.py
+++ b/nemo_rl/data/llm_message_utils.py
@@ -433,6 +433,7 @@ def get_formatted_message_log(
     add_eos_token: bool = True,
     add_generation_prompt: bool = False,
     tools: Optional[list[dict[str, Any]]] = None,
+    debug: bool = False,
 ) -> LLMMessageLogType:
     """Format and tokenize chat messages using the specified template.
 
@@ -444,6 +445,7 @@ def get_formatted_message_log(
         add_eos_token: Whether to add eos token to last message if it is not already present. Default: True
         add_generation_prompt: Whether to include assistant's generation prompt in user messages. Default: False
         tools: Optional list of tool/function definitions to pass to the chat template. Default: None
+        debug: Whether to print debug information showing each message turn. Default: False
     Returns:
         The message log with updated 'token_ids' and 'content' fields.
     """
@@ -537,8 +539,8 @@ def _format_content_helper(
         ## pull out the chunk corresponding to the current message
         message_chunk = formatted_message[prev_message_len_no_eos:]
 
-        # Debug: Print each message turn separately (only once for the first sample)
-        if not hasattr(get_formatted_message_log, "_debug_printed"):
+        # Debug: Print each message turn separately
+        if debug:
             if i == 0:
                 # Print header only at the start of first message
                 print("\n" + "=" * 80)
@@ -554,8 +556,6 @@ def _format_content_helper(
             print("-" * 40)
 
             if i == len(message_log_strs) - 1:
-                # Mark as printed after processing all turns of the first sample
-                get_formatted_message_log._debug_printed = True
                 print("\n" + "=" * 80)
                 print("DEBUG: Complete formatted conversation:")
                 print("-" * 80)
diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py
index 39b8fab49d..015280b484 100644
--- a/tests/unit/data/test_llm_message_utils.py
+++ b/tests/unit/data/test_llm_message_utils.py
@@ -792,3 +792,41 @@ def test_get_formatted_message_log_multimodal_prompt_formatting() -> None:
         isinstance(out[1]["token_ids"], torch.Tensor)
         and out[1]["token_ids"].numel() > 0
     )
+
+
+def test_get_formatted_message_log_debug_off_by_default(
+    raw_chat_message_log: LLMMessageLogType,
+    capsys,
+) -> None:
+    """Verify debug output is disabled by default."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+    task_data_spec = TaskDataSpec(task_name="test")
+
+    get_formatted_message_log(
+        raw_chat_message_log,
+        tokenizer,
+        task_data_spec,
+    )
+
+    captured = capsys.readouterr()
+    assert "DEBUG: Individual message turns" not in captured.out
+
+
+def test_get_formatted_message_log_debug_enabled(
+    raw_chat_message_log: LLMMessageLogType,
+    capsys,
+) -> None:
+    """Verify debug output is printed when debug=True."""
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+    task_data_spec = TaskDataSpec(task_name="test")
+
+    get_formatted_message_log(
+        raw_chat_message_log,
+        tokenizer,
+        task_data_spec,
+        debug=True,
+    )
+
+    captured = capsys.readouterr()
+    assert "DEBUG: Individual message turns from apply_chat_template" in captured.out
+    assert "DEBUG: Complete formatted conversation:" in captured.out