Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nemo_rl/data/llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def get_formatted_message_log(
add_eos_token: bool = True,
add_generation_prompt: bool = False,
tools: Optional[list[dict[str, Any]]] = None,
debug: bool = False,
) -> LLMMessageLogType:
"""Format and tokenize chat messages using the specified template.

Expand All @@ -444,6 +445,7 @@ def get_formatted_message_log(
add_eos_token: Whether to add eos token to last message if it is not already present. Default: True
add_generation_prompt: Whether to include assistant's generation prompt in user messages. Default: False
tools: Optional list of tool/function definitions to pass to the chat template. Default: None
debug: Whether to print debug information showing each message turn. Default: False
Returns:
The message log with updated 'token_ids' and 'content' fields.
"""
Expand Down Expand Up @@ -537,8 +539,8 @@ def _format_content_helper(
## pull out the chunk corresponding to the current message
message_chunk = formatted_message[prev_message_len_no_eos:]

# Debug: Print each message turn separately (only once for the first sample)
if not hasattr(get_formatted_message_log, "_debug_printed"):
# Debug: Print each message turn separately
if debug:
if i == 0:
# Print header only at the start of first message
print("\n" + "=" * 80)
Expand All @@ -554,8 +556,6 @@ def _format_content_helper(
print("-" * 40)

if i == len(message_log_strs) - 1:
# Mark as printed after processing all turns of the first sample
get_formatted_message_log._debug_printed = True
print("\n" + "=" * 80)
print("DEBUG: Complete formatted conversation:")
print("-" * 80)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/data/test_llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,41 @@ def test_get_formatted_message_log_multimodal_prompt_formatting() -> None:
isinstance(out[1]["token_ids"], torch.Tensor)
and out[1]["token_ids"].numel() > 0
)


def test_get_formatted_message_log_debug_off_by_default(
    raw_chat_message_log: LLMMessageLogType,
    capsys,
) -> None:
    """Ensure no debug output is emitted when the debug flag is left at its default."""
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    spec = TaskDataSpec(task_name="test")

    # Call without the debug kwarg — the default (False) must stay silent.
    get_formatted_message_log(raw_chat_message_log, tok, spec)

    out = capsys.readouterr().out
    assert "DEBUG: Individual message turns" not in out


def test_get_formatted_message_log_debug_enabled(
    raw_chat_message_log: LLMMessageLogType,
    capsys,
) -> None:
    """Ensure per-turn and full-conversation debug output appears when debug=True."""
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    spec = TaskDataSpec(task_name="test")

    get_formatted_message_log(
        raw_chat_message_log,
        tok,
        spec,
        debug=True,
    )

    out = capsys.readouterr().out
    # Both the per-turn header and the final combined dump must be present.
    for expected in (
        "DEBUG: Individual message turns from apply_chat_template",
        "DEBUG: Complete formatted conversation:",
    ):
        assert expected in out
Loading