Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions nemo_rl/data/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def math_data_processor(
add_generation_prompt=False,
add_special_tokens=False,
)
sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0]
sys_prompt["token_ids"] = tokenizer(
sys, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
message_log.append(sys_prompt)

# user prompt
Expand Down Expand Up @@ -138,7 +140,9 @@ def multichoice_qa_processor(
add_generation_prompt=False,
add_special_tokens=False,
)
sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0]
sys_prompt["token_ids"] = tokenizer(
sys, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
message_log.append(sys_prompt)

# user prompt
Expand All @@ -153,7 +157,9 @@ def multichoice_qa_processor(
add_generation_prompt=True,
add_special_tokens=False,
)
user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0]
user_message["token_ids"] = tokenizer(
message, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
user_message["content"] = message
message_log.append(user_message)

Expand Down
88 changes: 80 additions & 8 deletions tests/unit/data/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import sys
import tempfile
from collections import defaultdict

import pytest
Expand All @@ -25,6 +26,13 @@
from examples.run_grpo_math import hf_data_processor
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.eval_datasets import (
AIME2024Dataset,
AIME2025Dataset,
GPQADataset,
MathDataset,
MMLUDataset,
)
from nemo_rl.data.hf_datasets.deepscaler import DeepScalerDataset
from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec
Expand Down Expand Up @@ -78,18 +86,15 @@ def test_math_data_processor():
],
)
@pytest.mark.parametrize(
"dataset_name",
"dataset_cls",
[
"openmathinstruct2",
"deepscaler",
OpenMathInstruct2Dataset,
DeepScalerDataset,
],
)
def test_math_hf_data_processor(tokenizer_name, dataset_name):
def test_math_hf_data_processor(tokenizer_name, dataset_cls):
# Initialize dataset
if dataset_name == "openmathinstruct2":
data = OpenMathInstruct2Dataset()
elif dataset_name == "deepscaler":
data = DeepScalerDataset()
data = dataset_cls()

# Setup tokenizer
tokenizer = get_tokenizer(
Expand Down Expand Up @@ -124,3 +129,70 @@ def test_math_hf_data_processor(tokenizer_name, dataset_name):
assert first_item is not None
assert "message_log" in first_item
assert len(first_item["message_log"]) > 0


@pytest.fixture
def system_prompt_file(request):
    """Provide a system-prompt file path for tests, or None.

    Supports pytest indirect parametrization: when the indirect param is
    None, yield None so the test exercises the no-system-prompt code path.
    (The original fixture ignored ``request.param``, so the ``None`` case
    in the parametrize list still produced a temp file.)

    Yields:
        str | None: Path to a temp file containing a system prompt
        template with a ``{}`` placeholder, or None.
    """
    # Indirect parametrization stores the value on request.param; fall back
    # to a non-None default when the fixture is used without parametrization.
    param = getattr(request, "param", object())
    if param is None:
        yield None
        return

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file:
        file.write("You are a helpful assistant.\n{}")

    yield file.name
    # delete=False requires explicit cleanup, otherwise each run leaks a file.
    os.unlink(file.name)


@pytest.mark.hf_gated
@pytest.mark.parametrize(
    "tokenizer_name",
    [
        "meta-llama/Llama-3.2-1B-Instruct",
        "Qwen/Qwen2.5-1.5B-Instruct",  # no bos token
        "google/gemma-3-1b-it",
        "Qwen/Qwen3-0.6B",  # no bos token
        "deepseek-ai/DeepSeek-V3",
        "moonshotai/Moonlight-16B-A3B-Instruct",
    ],
)
@pytest.mark.parametrize(
    "dataset_cls",
    [
        MMLUDataset,
        GPQADataset,
        MathDataset,
        AIME2024Dataset,
        AIME2025Dataset,
    ],
)
@pytest.mark.parametrize(
    "system_prompt_file", [system_prompt_file, None], indirect=True
)
def test_eval_math_hf_data_processor(tokenizer_name, dataset_cls, system_prompt_file):
    """Smoke-test eval dataset processing across tokenizers and datasets.

    Builds the eval dataset, tokenizer, and task spec, then checks that the
    first processed item comes back with a non-empty message log (i.e. the
    BOS-token handling in the processor does not trip an assertion).
    """
    eval_data = dataset_cls()

    # Tokenizer under test — parametrized to cover models with and without
    # a BOS token.
    tokenizer = get_tokenizer(
        TokenizerConfig(
            name=tokenizer_name,
            chat_template="default",
        )
    )

    # Task spec pointing at the shared chain-of-thought prompt; the system
    # prompt file is parametrized (real file vs. None).
    task_spec = TaskDataSpec(
        task_name="math",
        prompt_file=f"{os.path.dirname(abspath)}/../../../examples/prompts/cot.txt",
        system_prompt_file=system_prompt_file,
    )

    processed = AllTaskProcessedDataset(
        dataset=eval_data.rekeyed_ds,
        tokenizer=tokenizer,
        default_task_data_spec=task_spec,
        task_data_processors=eval_data.processor,
        max_seq_length=128,
    )

    # Retrieving the first item succeeds only if the BOS token assertion
    # inside the processor passes.
    sample = processed[0]
    assert sample is not None
    assert "message_log" in sample
    assert len(sample["message_log"]) > 0