Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
77b03d7
make chat template configurable from config, save chat template as at…
ashors1 Apr 5, 2025
6116f77
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 7, 2025
4f17127
save hf tokenizer
ashors1 Apr 7, 2025
525296a
add sft example with json data
ashors1 Apr 4, 2025
ce04b9b
improved configurability
ashors1 Apr 7, 2025
1597d0f
fixes
ashors1 Apr 8, 2025
70f11a3
update grpo and clean up
ashors1 Apr 8, 2025
83552a6
fix unit tests
ashors1 Apr 8, 2025
23f06d2
address comments
ashors1 Apr 9, 2025
1816456
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 10, 2025
58fe349
add unit tests and documentation
ashors1 Apr 11, 2025
956c7f0
copyright header
ashors1 Apr 11, 2025
4913f25
address comments
ashors1 Apr 11, 2025
c016902
small fixes
ashors1 Apr 11, 2025
326b151
fix typo
ashors1 Apr 11, 2025
5068487
fix tests
ashors1 Apr 11, 2025
8e835ba
update chat template documentation
ashors1 Apr 15, 2025
0eeaa1e
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 15, 2025
d95d88b
fix unit tests
ashors1 Apr 15, 2025
1958bb9
fix doctest
ashors1 Apr 15, 2025
c536b55
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 16, 2025
2c7c5c5
fix checkpoint save when tokenizer not provided
ashors1 Apr 16, 2025
1df9a9f
feat: add a unique seed for each vllm llm engine (#171)
parthchadha Apr 15, 2025
f35ad95
fix: unit test script halts on first failure (#189)
terrykong Apr 15, 2025
844e470
fix new vllm test and doctest
ashors1 Apr 16, 2025
7e50e8e
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 16, 2025
ac4b6ea
remove old comment
ashors1 Apr 16, 2025
c5328f0
fix doctest
ashors1 Apr 16, 2025
5c4f849
Merge branch 'main' into ashors/chat-template-improvements
SahilJain314 Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ checkpointing:

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 512
train_micro_batch_size: 4
generation_batch_size: 32 # Only used when generating using HF backend
Expand Down
3 changes: 2 additions & 1 deletion examples/configs/grpo_math_8B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ grpo:

policy:
model_name: "meta-llama/Llama-3.1-8B-Instruct"
tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 512
train_micro_batch_size: 1
generation_batch_size: 32 # Only used when generating using HF backend
Expand Down
4 changes: 3 additions & 1 deletion examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ checkpointing:

policy:
model_name: "meta-llama/Llama-3.2-1B"
tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
train_global_batch_size: 32
train_micro_batch_size: 1
max_total_sequence_length: 1024
Expand Down
4 changes: 1 addition & 3 deletions examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,13 @@ def openinstructmath2_data_processor(
problem = user_message[0]["content"]
extra_env_info = {"ground_truth": user_message[1]["content"]}

template = task_data_spec.custom_template
message_log: LLMMessageLogType = []
user_message = {
"role": "user",
"content": task_data_spec.prompt.format(problem),
}
message = tokenizer.apply_chat_template(
[user_message],
chat_template=template,
tokenize=False,
add_generation_prompt=True,
add_special_tokens=False,
Expand Down Expand Up @@ -254,7 +252,7 @@ def main():
init_ray()

# setup tokenizer
tokenizer = get_tokenizer(config["policy"]["model_name"])
tokenizer = get_tokenizer(config["policy"]["tokenizer"])
config["policy"]["generation"] = configure_generation_config(
config["policy"]["generation"], tokenizer
)
Expand Down
18 changes: 14 additions & 4 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def sft_preprocessor(
) -> DatumSpec:
"""Process a datum dictionary for SFT training."""
message_log = get_formatted_message_log(
datum_dict["messages"], tokenizer, task_data_spec
datum_dict["messages"],
tokenizer,
task_data_spec,
add_bos_token=True,
add_eos_token=True,
)

length = sum(len(m["token_ids"]) for m in message_log)
Expand Down Expand Up @@ -90,6 +94,13 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
elif data_cls == "squad":
data = hf_datasets.SquadDataset()
elif data_cls == "prompt_response_dataset":
data = hf_datasets.PromptResponseDataset(
data_config["train_data_path"],
data_config["val_data_path"],
data_config["input_key"],
data_config["output_key"],
)
else:
raise ValueError(f"Unknown dataset class: {data_cls}")
print(
Expand Down Expand Up @@ -151,7 +162,7 @@ def main():
init_ray()

# setup tokenizer
tokenizer = get_tokenizer(config["policy"]["model_name"])
tokenizer = get_tokenizer(config["policy"]["tokenizer"])

# setup data
(
Expand All @@ -170,8 +181,7 @@ def main():
checkpointer,
sft_save_state,
master_config,
) = setup(config, dataset, val_dataset)

) = setup(config, tokenizer, dataset, val_dataset)
sft_train(
policy,
train_dataloader,
Expand Down
1 change: 1 addition & 0 deletions nemo_reinforcer/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def setup(
policy = HfPolicy(
cluster=cluster,
config=policy_config,
tokenizer=tokenizer,
weights_path=Path(last_checkpoint_path) / "policy" / "weights"
if last_checkpoint_path
else None,
Expand Down
3 changes: 3 additions & 0 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from transformers import AutoTokenizer
from pathlib import Path
from typing import Optional, Tuple, TypedDict

Expand Down Expand Up @@ -76,6 +77,7 @@ class MasterConfig(TypedDict):
# =======================================================
def setup(
master_config: MasterConfig,
tokenizer: AutoTokenizer,
train_dataset: AllTaskProcessedDataset,
val_dataset: AllTaskProcessedDataset,
) -> Tuple[
Expand Down Expand Up @@ -175,6 +177,7 @@ def setup(
policy = HfPolicy(
cluster=cluster,
config=policy_config,
tokenizer=tokenizer,
weights_path=Path(last_checkpoint_path) / "policy" / "weights"
if last_checkpoint_path
else None,
Expand Down
14 changes: 12 additions & 2 deletions nemo_reinforcer/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from torch.masked import as_masked_tensor
from transformers import AutoTokenizer

from nemo_reinforcer.data import hf_datasets
from nemo_reinforcer.models.policy import TokenizerConfig


def calculate_kl_penalty_joschu2020(
logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor
Expand Down Expand Up @@ -133,9 +136,16 @@ def set_seed(seed: int):
torch.cuda.manual_seed_all(seed)


def get_tokenizer(model_name: str) -> AutoTokenizer:
def get_tokenizer(tokenizer_config: TokenizerConfig) -> AutoTokenizer:
"""Get the tokenizer and set pad token to eos token if it is not already set."""
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if "chat_template" in tokenizer_config:
if tokenizer_config["chat_template"] is None:
Comment thread
SahilJain314 marked this conversation as resolved.
tokenizer.chat_template = (
hf_datasets.COMMON_CHAT_TEMPLATES.passthrough_prompt_response
)
else:
tokenizer.chat_template = tokenizer_config["chat_template"]
return tokenizer
11 changes: 10 additions & 1 deletion nemo_reinforcer/data/hf_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_reinforcer.data.hf_datasets.prompt_response_dataset import (
PromptResponseDataset,
)
from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset
from nemo_reinforcer.data.hf_datasets.squad import SquadDataset
from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES

__all__ = ["OasstDataset", "SquadDataset"]
__all__ = [
"OasstDataset",
"PromptResponseDataset",
"SquadDataset",
"COMMON_CHAT_TEMPLATES",
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Any, Optional
from nemo_reinforcer.data.interfaces import TaskDataSpec


## a reference to frequently used chat templates for convenience
class COMMON_CHAT_TEMPLATES:
### simple template which prepends a role header to the content
simple_role_header = "{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"


class HfDataset:
"""Interface for HuggingFace datasets."""

formatted_ds: Dict[str, Any]

def __init__(
self,
dataset_name: str,
custom_template: Optional[
str
] = None, ## "None" means use HuggingFace's tokenizer's template
):
self.task_spec = TaskDataSpec(
task_name=dataset_name,
custom_template=custom_template,
)
### passthrough template which just concatenates the content of the messages with no special tokens
passthrough_prompt_response = (
"{% for message in messages %}{{ message['content'] }}{% endfor %}"
)
10 changes: 5 additions & 5 deletions nemo_reinforcer/data/hf_datasets/oasst.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import copy
from dataclasses import dataclass
from typing import Optional
from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset, COMMON_CHAT_TEMPLATES

from nemo_reinforcer.data.interfaces import TaskDataSpec

SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"

Expand Down Expand Up @@ -119,10 +120,9 @@ def download_and_process_oasst(output_directory=".", seed=42, split_ratio=0.95):


@dataclass
Comment thread
ashors1 marked this conversation as resolved.
Outdated
class OasstDataset(HfDataset):
class OasstDataset:
def __init__(self, output_dir: str = "."):
self.formatted_ds = download_and_process_oasst(output_dir)
super().__init__(
dataset_name="oasst",
custom_template=COMMON_CHAT_TEMPLATES.simple_role_header,
self.task_spec = TaskDataSpec(
task_name="OASST",
)
9 changes: 5 additions & 4 deletions nemo_reinforcer/data/hf_datasets/openmathinstruct2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

from typing import Optional
from datasets import load_dataset
from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset
from dataclasses import dataclass

from nemo_reinforcer.data.interfaces import TaskDataSpec


def format_math(data):
return {
Expand Down Expand Up @@ -62,7 +63,7 @@ def prepare_openinstructmath2_dataset(split: str = "train_1M", seed=42, test_siz


@dataclass
class OpenMathInstruct2Dataset(HfDataset):
class OpenMathInstruct2Dataset:
def __init__(
self, split: str = "train_1M", seed: int = 42, test_size: float = 0.05
):
Expand All @@ -82,6 +83,6 @@ def __init__(
split=split, seed=seed, test_size=test_size
)

super().__init__(
dataset_name="OpenMathInstruct-2",
self.task_spec = TaskDataSpec(
task_name="OpenMathInstruct-2",
)
52 changes: 52 additions & 0 deletions nemo_reinforcer/data/hf_datasets/prompt_response_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset
from nemo_reinforcer.data.interfaces import TaskDataSpec


class PromptResponseDataset:
Comment thread
ashors1 marked this conversation as resolved.
def __init__(
self,
train_ds_path: str,
val_ds_path: str,
input_key: str = "input",
output_key: str = "output",
):
train_original_dataset = load_dataset("json", data_files=train_ds_path)["train"]
val_original_dataset = load_dataset("json", data_files=val_ds_path)["train"]

self.input_key = input_key
self.output_key = output_key

formatted_train_dataset = train_original_dataset.map(self.add_messages_key)
formatted_val_dataset = val_original_dataset.map(self.add_messages_key)

## just duplicating the dataset for train and validation for simplicity
self.formatted_ds = {
"train": formatted_train_dataset,
"validation": formatted_val_dataset,
}

self.task_spec = TaskDataSpec(
"json_dataset",
)

def add_messages_key(self, example):
return {
"messages": [
{"role": "user", "content": example[self.input_key]},
{"role": "assistant", "content": example[self.output_key]},
]
}
13 changes: 5 additions & 8 deletions nemo_reinforcer/data/hf_datasets/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from typing import Optional
from datasets import load_dataset
from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset

from nemo_reinforcer.data.interfaces import TaskDataSpec


def format_squad(data):
Expand All @@ -36,14 +37,10 @@ def format_squad(data):
}


class SquadDataset(HfDataset):
class SquadDataset:
def __init__(self):
original_ds = load_dataset("rajpurkar/squad")
self.formatted_ds = original_ds.map(format_squad)

custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"

super().__init__(
dataset_name="squad",
custom_template=custom_template,
self.task_spec = TaskDataSpec(
task_name="SQuAD",
)
2 changes: 0 additions & 2 deletions nemo_reinforcer/data/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class TaskDataSpec:
prompt_file: Optional[os.PathLike] = None

system_prompt_file: Optional[Union[str, os.PathLike]] = None
custom_template: Optional[Union[str, os.PathLike]] = None

def __post_init__(self):
def load_prompt_file(
Expand All @@ -66,7 +65,6 @@ def copy_defaults(self, from_spec: "TaskDataSpec"):
default_attrs = {
"system_prompt": from_spec.system_prompt,
"prompt": from_spec.prompt,
"custom_template": from_spec.custom_template,
}

for attr_name, default_value in default_attrs.items():
Expand Down
Loading