diff --git a/examples/run_sft.py b/examples/run_sft.py
index 8eb93b5adc..ce5b258b0c 100644
--- a/examples/run_sft.py
+++ b/examples/run_sft.py
@@ -109,6 +109,16 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
             output_key=data_config["output_key"],
             prompt_file=data_config["prompt_file"],
         )
+    elif data_cls == "openai_format":
+        # chat_key/system_key/system_prompt are optional in OpenAIFormatDataset,
+        # so read them with .get() to avoid a KeyError when a config omits them.
+        data = hf_datasets.OpenAIFormatDataset(
+            data_config["train_data_path"],
+            data_config["val_data_path"],
+            chat_key=data_config.get("chat_key", "messages"),
+            system_key=data_config.get("system_key"),
+            system_prompt=data_config.get("system_prompt"),
+        )
     else:
         raise ValueError(f"Unknown dataset class: {data_cls}")
     print(
diff --git a/nemo_rl/data/hf_datasets/__init__.py b/nemo_rl/data/hf_datasets/__init__.py
index 54d4fd9c34..aa5596397c 100644
--- a/nemo_rl/data/hf_datasets/__init__.py
+++ b/nemo_rl/data/hf_datasets/__init__.py
@@ -15,6 +15,7 @@
 from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES
 from nemo_rl.data.hf_datasets.dpo import DPODataset
 from nemo_rl.data.hf_datasets.helpsteer3 import HelpSteer3Dataset
+from nemo_rl.data.hf_datasets.oai_format_dataset import OpenAIFormatDataset
 from nemo_rl.data.hf_datasets.oasst import OasstDataset
 from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
 from nemo_rl.data.hf_datasets.prompt_response_dataset import (
@@ -26,6 +27,7 @@
     "DPODataset",
     "HelpSteer3Dataset",
     "OasstDataset",
+    "OpenAIFormatDataset",
     "OpenMathInstruct2Dataset",
     "PromptResponseDataset",
     "SquadDataset",
diff --git a/nemo_rl/data/hf_datasets/oai_format_dataset.py b/nemo_rl/data/hf_datasets/oai_format_dataset.py
new file mode 100644
index 0000000000..22d01346bc
--- /dev/null
+++ b/nemo_rl/data/hf_datasets/oai_format_dataset.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+from datasets import load_dataset
+
+from nemo_rl.data.interfaces import TaskDataSpec
+
+
+class OpenAIFormatDataset:
+    """Loads an SFT dataset stored in the OpenAI chat format.
+
+    The dataset should be in the following format:
+    {
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
+            {"role": "assistant", "content": "The capital of France is Paris."}
+        ]
+    }
+
+    system_key and system_prompt are optional. If provided, a system message is
+    prepended to each conversation; a per-example system_key value takes
+    precedence over the global system_prompt.
+    chat_key should be the key of the messages list. Multi-turn conversations
+    are supported.
+    The last message in each conversation must be from the assistant.
+    """
+
+    def __init__(
+        self,
+        train_ds_path: str,
+        val_ds_path: str,
+        chat_key: str = "messages",
+        system_key: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+    ):
+        self.chat_key = chat_key
+        self.system_key = system_key
+        self.system_prompt = system_prompt
+        # load_dataset("json", ...) always exposes the file under the "train"
+        # split, regardless of whether it holds train or validation data.
+        train_original_dataset = load_dataset("json", data_files=train_ds_path)["train"]
+        val_original_dataset = load_dataset("json", data_files=val_ds_path)["train"]
+
+        formatted_train_dataset = train_original_dataset.map(self.add_messages_key)
+        formatted_val_dataset = val_original_dataset.map(self.add_messages_key)
+
+        self.formatted_ds = {
+            "train": formatted_train_dataset,
+            "validation": formatted_val_dataset,
+        }
+
+        self.task_spec = TaskDataSpec(
+            "json_dataset",
+        )
+
+    def add_messages_key(
+        self,
+        example: dict[str, Any],
+    ) -> dict[str, list[dict[str, Any]]]:
+        """Normalizes one raw example into the canonical {"messages": [...]} form."""
+        messages = list(example[self.chat_key])
+        # Guard against system_key=None: a bare `None in example` would look up
+        # a literal None key instead of meaning "no system key configured".
+        if self.system_key is not None and self.system_key in example:
+            messages = [
+                {"role": "system", "content": example[self.system_key]}
+            ] + messages
+        elif self.system_prompt:
+            messages = [{"role": "system", "content": self.system_prompt}] + messages
+        # Raise (rather than assert) so the check survives `python -O`.
+        if messages[-1]["role"] != "assistant":
+            raise ValueError(
+                "The last message in each conversation must be from the assistant."
+            )
+        return {"messages": messages}
diff --git a/tests/unit/data/hf_datasets/test_oai_format_dataset.py b/tests/unit/data/hf_datasets/test_oai_format_dataset.py
new file mode 100644
index 0000000000..4ba75a6a1d
--- /dev/null
+++ b/tests/unit/data/hf_datasets/test_oai_format_dataset.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import tempfile
+
+import pytest
+from transformers import AutoTokenizer
+
+from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES
+from nemo_rl.data.hf_datasets.oai_format_dataset import (
+    OpenAIFormatDataset,
+)
+
+
+@pytest.fixture
+def sample_data(request):
+    """Writes a one-example train/val JSON pair and yields their file paths.
+
+    Parametrized indirectly with (chat_key, system_key); system_key=None means
+    no system field is written.
+    """
+    chat_key, system_key = request.param
+
+    train_data = {
+        chat_key: [
+            {"role": "user", "content": "What is the capital of France?"},
+            {"role": "assistant", "content": "The capital of France is Paris."},
+        ],
+    }
+    val_data = {
+        chat_key: [
+            {"role": "user", "content": "What is the capital of Germany?"},
+            {"role": "assistant", "content": "The capital of Germany is Berlin."},
+        ],
+    }
+
+    if system_key is not None:
+        train_data[system_key] = "You are a helpful assistant."
+        val_data[system_key] = "You are a helpful assistant."
+
+    # Create temporary files for train and validation data
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as train_file:
+        json.dump(train_data, train_file)
+        train_path = train_file.name
+
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as val_file:
+        json.dump(val_data, val_file)
+        val_path = val_file.name
+
+    yield train_path, val_path
+
+    # delete=False would otherwise leak the temp files after every test run.
+    os.unlink(train_path)
+    os.unlink(val_path)
+
+
+@pytest.mark.parametrize("sample_data", [("messages", None)], indirect=True)
+def test_dataset_initialization(sample_data):
+    train_path, val_path = sample_data
+    dataset = OpenAIFormatDataset(train_path, val_path)
+
+    assert dataset.chat_key == "messages"
+    assert "train" in dataset.formatted_ds
+    assert "validation" in dataset.formatted_ds
+
+
+@pytest.mark.parametrize("sample_data", [("conversations", None)], indirect=True)
+def test_custom_keys(sample_data):
+    train_path, val_path = sample_data
+    dataset = OpenAIFormatDataset(
+        train_path,
+        val_path,
+        chat_key="conversations",
+        system_prompt="You are a helpful assistant.",
+    )
+
+    assert dataset.chat_key == "conversations"
+    assert dataset.system_prompt == "You are a helpful assistant."
+
+
+@pytest.mark.parametrize("sample_data", [("messages", "system_key")], indirect=True)
+def test_message_formatting(sample_data):
+    train_path, val_path = sample_data
+    dataset = OpenAIFormatDataset(
+        train_path, val_path, chat_key="messages", system_key="system_key"
+    )
+
+    first_example = dataset.formatted_ds["train"][0]
+
+    assert first_example["messages"][0]["role"] == "system"
+    assert first_example["messages"][0]["content"] == "You are a helpful assistant."
+    assert first_example["messages"][1]["role"] == "user"
+    assert first_example["messages"][1]["content"] == "What is the capital of France?"
+    assert first_example["messages"][2]["role"] == "assistant"
+    assert first_example["messages"][2]["content"] == "The capital of France is Paris."
+
+    chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response
+    # NOTE(review): this downloads a gated HF model (network + auth required)
+    # inside a unit test — confirm CI has access, or swap in a small local
+    # tokenizer that only needs the chat template.
+    tokenizer = AutoTokenizer.from_pretrained("Meta-Llama/Meta-Llama-3-8B-Instruct")
+
+    combined_message = tokenizer.apply_chat_template(
+        first_example["messages"],
+        chat_template=chat_template,
+        tokenize=False,
+        add_generation_prompt=False,
+        add_special_tokens=False,
+    )
+
+    assert combined_message == "".join(
+        message["content"] for message in first_example["messages"]
+    )