diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index 46c4d56197..ce68546993 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -32,129 +32,89 @@ uv run examples/run_dpo.py \ ## Datasets -Each class representing a NeMo RL DPO dataset is expected to have the following attributes: -1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below. -2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset. - -DPO datasets are expected to follow a specific format with three key fields: -- `prompt`: The input prompt/context -- `chosen_response`: The preferred/winning response -- `rejected_response`: The non-preferred/losing response - -[data/hf_datasets/helpsteer3.py](../../nemo_rl/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO: - -```python -def format_helpsteer3(data): - response_1 = data["response1"] - response_2 = data["response2"] - overall_preference = data["overall_preference"] - - if overall_preference < 0: - chosen = response_1 - rejected = response_2 - elif overall_preference == 0: - chosen = response_1 - rejected = response_1 - else: - chosen = response_2 - rejected = response_1 - - return { - "prompt": data["context"], - "chosen_response": chosen, - "rejected_response": rejected, +Each DPO dataset class is expected to have the following attributes: +1. `formatted_ds`: The dictionary of formatted datasets, where each dataset should be formatted like +```json +{ + "context": [], // list of dicts - The prompt message (including previous turns, if any) + "completions": [ // list of dicts — The list of completions + { + "rank": 0, // int — The rank of the completion (lower rank is preferred) + "completion": [] // list of dicts — The completion message(s) + }, + { + "rank": 1, // int — The rank of the completion (lower rank is preferred) + "completion": [] // list of dicts — The completion message(s) } + ] +} ``` +2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset. -We also provide a [DPODataset](../../nemo_rl/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datsets. This class assumes train and validation datasets have been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys. - -## Adding Custom DPO Datasets - -Adding a new DPO dataset is straightforward. Your custom dataset class should: -1. Implement the required format conversion in the constructor -2. Set up the appropriate `task_spec` - -Here's a minimal example which simply re-keys an existing jsonl dataset: - -```{testcode} -from datasets import load_dataset -from nemo_rl.data.interfaces import TaskDataSpec -from docs.helpers import make_dpo_dataset - -class CustomDPODataset: - def preprocess_dataset( - self, - data, - prompt_key: str = "context", - chosen_key: str = "chosen", - rejected_key: str = "rejected" - ): - return { - "prompt": data[prompt_key], - "chosen_response": data[chosen_key], - "rejected_response": data[rejected_key], +DPO training supports only two completions (where the lowest rank is preferred and the highest one is rejected), with each completion being a single response. For example: +```json +{ + "context": [ + { + "role": "user", + "content": "What's the capital of France?" + }, + { + "role": "assistant", + "content": "The capital of France is Paris." + }, + { + "role": "user", + "content": "Thanks! And what's the capital of Germany?" } - - def __init__( - self, - train_data_path: str, - val_data_path: str, - prompt_key: str, - chosen_key: str, - rejected_key: str, - ): - # Load and format your dataset - fn_kwargs={ - "prompt_key": prompt_key, - "chosen_key": chosen_key, - "rejected_key": rejected_key - } - formatted_ds = { - "train": load_dataset("json", data_files=train_data_path, split="train").map( - self.preprocess_dataset, - fn_kwargs=fn_kwargs, - ), - "validation": load_dataset("json", data_files=val_data_path, split="train").map( - self.preprocess_dataset, - fn_kwargs=fn_kwargs, - ), + ], + "completions": [ + { + "rank": 0, + "completion": [ + { + "role": "assistant", + "content": "The capital of Germany is Berlin." + } + ] + }, + { + "rank": 1, + "completion": [ + { + "role": "assistant", + "content": "The capital of Germany is Munich." + } + ] } - - # Initialize task spec with dataset name - self.task_spec = TaskDataSpec( - task_name="custom_dpo", - ) - self.formatted_ds = formatted_ds - -# Create temporary files using helper function -train_file, val_file = make_dpo_dataset() - -# Initialize dataset -dataset = CustomDPODataset( - train_data_path=train_file.name, - val_data_path=val_file.name, - prompt_key="context", - chosen_key="chosen", - rejected_key="rejected" -) - -# Test dataset properties -print(f"Task name: {dataset.task_spec.task_name}") -print(f"Train examples: {len(dataset.formatted_ds['train'])}") -print(f"Validation examples: {len(dataset.formatted_ds['validation'])}") -print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}") -print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}") -print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}") + ] +} ``` -```{testoutput} -Task name: custom_dpo -Train examples: 2 -Validation examples: 2 -First train example prompt: What is 2+2? -First train example chosen response: 4 -First train example rejected response: 5 +NeMo RL provides a DPO-compatible implementation of the [HelpSteer3](https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/data/hf_datasets/helpsteer3.py) dataset as an example. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk. + +We also provide a [PreferenceDataset](../../nemo_rl/data/hf_datasets/preference_dataset.py) class that is compatible with JSONL-formatted preference datasets. You can modify your config as follows to use such a custom preference dataset: +```yaml +data: + dataset_name: PreferenceDataset + train_data_path: + val_data_paths: + : +``` +with support for multiple validation sets achieved with: +```yaml +data: + dataset_name: PreferenceDataset + train_data_path: + val_data_paths: + : + : ``` +Please note: +- If you are using a logger, the prefix used for each validation set will be `validation-`. The total validation time, summed across all validation sets, is reported under `timing/validation/total_validation_time`. +- If you are doing checkpointing, the `metric_name` value in your `checkpointing` config should reflect the metric and validation set to be tracked. For example, `validation-_loss`. + +The older [DPODataset](../../nemo_rl/data/hf_datasets/dpo.py) class is deprecated. This class is also compatible with JSONL-formatted preference datsets. It assumes train and validation datasets have been split and processed into the expected format offline. The JSONL files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys. ## DPO-Specific Parameters diff --git a/docs/guides/rm.md b/docs/guides/rm.md index 40f8ded1dc..f1843cd92c 100644 --- a/docs/guides/rm.md +++ b/docs/guides/rm.md @@ -21,4 +21,84 @@ The default YAML config shares the same base template as the SFT config but incl ## Datasets -By default, NeMo RL supports the `HelpSteer3` dataset. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk. +Each RM dataset class is expected to have the following attributes: +1. `formatted_ds`: The dictionary of formatted datasets, where each dataset should be formatted like +```json +{ + "context": [], // list of dicts - The prompt message (including previous turns, if any) + "completions": [ // list of dicts — The list of completions + { + "rank": 0, // int — The rank of the completion (lower rank is preferred) + "completion": [] // list of dicts — The completion message(s) + }, + { + "rank": 1, // int — The rank of the completion (lower rank is preferred) + "completion": [] // list of dicts — The completion message(s) + } + ] +} +``` +2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset. + +Currently, RM training supports only two completions (where the lowest rank is preferred and the highest one is rejected), with each completion being a single response. For example: +```json +{ + "context": [ + { + "role": "user", + "content": "What's the capital of France?" + }, + { + "role": "assistant", + "content": "The capital of France is Paris." + }, + { + "role": "user", + "content": "Thanks! And what's the capital of Germany?" + } + ], + "completions": [ + { + "rank": 0, + "completion": [ + { + "role": "assistant", + "content": "The capital of Germany is Berlin." + } + ] + }, + { + "rank": 1, + "completion": [ + { + "role": "assistant", + "content": "The capital of Germany is Munich." + } + ] + } + ] +} +``` + +NeMo RL provides a RM-compatible implementation of the [HelpSteer3](https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/data/hf_datasets/helpsteer3.py) dataset as an example. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk. + +We also provide a [PreferenceDataset](../../nemo_rl/data/hf_datasets/preference_dataset.py) class that is compatible with JSONL-formatted preference datasets. You can modify your config as follows to use such a custom preference dataset: +```yaml +data: + dataset_name: PreferenceDataset + train_data_path: + val_data_paths: + : +``` +with support for multiple validation sets achieved with: +```yaml +data: + dataset_name: PreferenceDataset + train_data_path: + val_data_paths: + : + : +``` +Please note: +- If you are using a logger, the prefix used for each validation set will be `validation-`. The total validation time, summed across all validation sets, is reported under `timing/validation/total_validation_time`. +- If you are doing checkpointing, the `metric_name` value in your `checkpointing` config should reflect the metric and validation set to be tracked. For example, `validation-_loss`. \ No newline at end of file diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index cfe2c011e3..4a438e127e 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -151,9 +151,22 @@ policy: data_parallel_sharding_strategy: "optim_grads_params" data: - dataset_name: "HelpSteer3" max_input_seq_length: ${policy.max_total_sequence_length} shuffle: true + + dataset_name: HelpSteer3 + # You can use custom preference datasets for training and validation. For example: + # data: + # dataset_name: PreferenceDataset + # train_data_path: + # val_data_paths: + # : + # ... + # If you are doing checkpointing, `metric_name` should reflect the metric and validation set to be tracked. For example: + # checkpointing: + # metric_name: "validation-_loss" + # ... + logger: log_dir: "logs" # Base directory for all logs wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml index 46bb5ee910..744538d5ed 100644 --- a/examples/configs/rm.yaml +++ b/examples/configs/rm.yaml @@ -123,9 +123,21 @@ policy: data: max_input_seq_length: ${policy.max_total_sequence_length} - dataset_name: "HelpSteer3" shuffle: true + dataset_name: HelpSteer3 + # You can use custom preference datasets for training and validation. For example: + # data: + # dataset_name: PreferenceDataset + # train_data_path: + # val_data_paths: + # : + # ... + # If you are doing checkpointing, `metric_name` should reflect the metric and validation set to be tracked. For example: + # checkpointing: + # metric_name: "validation-_loss" + # ... + logger: log_dir: "logs" # Base directory for all logs wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 885f957d74..b9b31cfcf6 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -69,9 +69,11 @@ def dpo_preprocessor( >>> task_spec = TaskDataSpec(task_name="test_dpo") >>> >>> datum = { - ... "prompt": "What is 2+2?", - ... "chosen_response": "4", - ... "rejected_response": "5" + ... "context": [{"role": "user", "content": "What is 2+2?"}], + ... "completions": [ + ... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, + ... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} + ... ] ... } >>> >>> processed = dpo_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) @@ -84,11 +86,13 @@ def dpo_preprocessor( >>> processed["message_log_rejected"][-1]["content"] '5<|eot_id|>' >>> - >>> # prompt can also be a list with multiple messages + >>> # context can also contain multiple turns >>> datum = { - ... "prompt": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], - ... "chosen_response": "4", - ... "rejected_response": "5" + ... "context": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], + ... "completions": [ + ... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, + ... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} + ... ] ... } >>> processed = dpo_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) >>> len(processed["message_log_chosen"]) @@ -102,36 +106,23 @@ def dpo_preprocessor( ``` """ - if isinstance(datum_dict["prompt"], list): - messages_chosen = datum_dict["prompt"].copy() - messages_rejected = datum_dict["prompt"].copy() - else: - messages_chosen = [ - { - "role": "user", - "content": datum_dict["prompt"], - }, - ] - messages_rejected = [ - { - "role": "user", - "content": datum_dict["prompt"], - }, - ] - - messages_chosen.append( - { - "role": "assistant", - "content": datum_dict["chosen_response"], - }, + assert len(datum_dict["completions"]) == 2, ( + "DPO training supports only two completions" ) + # Lower rank is preferred + if datum_dict["completions"][0]["rank"] < datum_dict["completions"][1]["rank"]: + chosen_completion = datum_dict["completions"][0] + rejected_completion = datum_dict["completions"][1] + elif datum_dict["completions"][0]["rank"] > datum_dict["completions"][1]["rank"]: + chosen_completion = datum_dict["completions"][1] + rejected_completion = datum_dict["completions"][0] + else: + raise NotImplementedError( + "Ties are not supported yet. You can use the following command to filter out ties: `cat | jq 'select(.completions[0].rank != .completions[1].rank)'`." + ) - messages_rejected.append( - { - "role": "assistant", - "content": datum_dict["rejected_response"], - }, - ) + messages_chosen = datum_dict["context"] + chosen_completion["completion"] + messages_rejected = datum_dict["context"] + rejected_completion["completion"] message_log_chosen = get_formatted_message_log( messages_chosen, tokenizer, task_data_spec @@ -173,22 +164,41 @@ def dpo_preprocessor( def setup_data(data_config: DataConfig, policy_config: PolicyConfig): print("\n▶ Setting up data...") + data_cls = data_config["dataset_name"] - if data_config["dataset_name"] == "HelpSteer3": + if data_cls == "PreferenceDataset": + data_path = data_config["train_data_path"] + data = hf_datasets.PreferenceDataset(data_path, split="train") + train_dataset = data.formatted_ds["train"] + val_dataset = None + elif data_cls == "HelpSteer3": data = hf_datasets.HelpSteer3Dataset() train_dataset = data.formatted_ds["train"] val_dataset = data.formatted_ds["validation"] - elif data_config["dataset_name"] == "Tulu3Preference": + elif data_cls == "Tulu3Preference": data = hf_datasets.Tulu3PreferenceDataset() train_dataset = data.formatted_ds["train"] val_dataset = None - else: + elif data_cls == "DPODataset": data = hf_datasets.DPODataset( train_data_path=data_config["train_data_path"], val_data_path=data_config["val_data_path"], ) train_dataset = data.formatted_ds["train"] val_dataset = data.formatted_ds["validation"] + else: + raise ValueError( + f"Unknown dataset class: {data_cls}. Supported datasets are: PreferenceDataset, HelpSteer3, Tulu3Preference, and DPODataset (deprecated)." + ) + + if train_dataset: + print( + f" ✓ Training dataset loaded with {len(data.formatted_ds['train'])} samples." + ) + if val_dataset: + print( + f" ✓ Validation dataset loaded with {len(data.formatted_ds['validation'])} samples." + ) dpo_task_spec = data.task_spec @@ -201,13 +211,48 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig): max_seq_length=data_config["max_input_seq_length"], ) - if val_dataset: - val_dataset = AllTaskProcessedDataset( - val_dataset, - tokenizer, - dpo_task_spec, - dpo_preprocessor, - max_seq_length=data_config["max_input_seq_length"], + if data_cls == "PreferenceDataset": + val_dataset = {} + + assert "val_data_path" not in data_config, ( + "`val_data_path` cannot be provided for PreferenceDataset. You should use `val_data_paths` instead." + ) + assert "val_data_paths" in data_config, ( + "`val_data_paths` must be provided for PreferenceDataset" + ) + assert isinstance(data_config["val_data_paths"], dict), ( + f"Invalid type for val_data_paths: {type(data_config['val_data_paths'])}. val_data_paths must be a dictionary." + ) + val_data_paths = data_config["val_data_paths"] + + for val_dataset_name, val_dataset_path in val_data_paths.items(): + assert val_dataset_name not in val_dataset + val_data = hf_datasets.PreferenceDataset( + val_dataset_path, split="validation" + ) + print( + f" ✓ Validation dataset '{val_dataset_name}' loaded with {len(val_data.formatted_ds['validation'])} samples." + ) + val_dataset[val_dataset_name] = AllTaskProcessedDataset( + val_data.formatted_ds["validation"], + tokenizer, + val_data.task_spec, + dpo_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + else: + val_dataset = ( + { + "default": AllTaskProcessedDataset( + val_dataset, + tokenizer, + dpo_task_spec, + dpo_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + } + if val_dataset + else {} ) return train_dataset, val_dataset, tokenizer, dpo_task_spec diff --git a/examples/run_rm.py b/examples/run_rm.py index 6586d8edb7..0adf84490d 100644 --- a/examples/run_rm.py +++ b/examples/run_rm.py @@ -56,12 +56,23 @@ def rm_preprocessor( idx: int, ) -> DatumSpec: """Process a datum dictionary for RM training.""" - messages_chosen = datum_dict["prompt"] + [ - {"role": "assistant", "content": datum_dict["chosen_response"]} - ] - messages_rejected = datum_dict["prompt"] + [ - {"role": "assistant", "content": datum_dict["rejected_response"]} - ] + assert len(datum_dict["completions"]) == 2, ( + "RM training supports only two completions" + ) + # Lower rank is preferred + if datum_dict["completions"][0]["rank"] < datum_dict["completions"][1]["rank"]: + chosen_completion = datum_dict["completions"][0] + rejected_completion = datum_dict["completions"][1] + elif datum_dict["completions"][0]["rank"] > datum_dict["completions"][1]["rank"]: + chosen_completion = datum_dict["completions"][1] + rejected_completion = datum_dict["completions"][0] + else: + raise NotImplementedError( + "Ties are not supported yet. You can use the following command to filter out ties: `cat | jq 'select(.completions[0].rank != .completions[1].rank)'`." + ) + + messages_chosen = datum_dict["context"] + chosen_completion["completion"] + messages_rejected = datum_dict["context"] + rejected_completion["completion"] message_log_chosen = get_formatted_message_log( messages_chosen, tokenizer, task_data_spec @@ -111,16 +122,33 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): print("\n▶ Setting up data...") data_cls = data_config["dataset_name"] - if data_cls == "HelpSteer3": + if data_cls == "PreferenceDataset": + data_path = data_config["train_data_path"] + data = hf_datasets.PreferenceDataset(data_path, split="train") + train_dataset = data.formatted_ds["train"] + val_dataset = None + elif data_cls == "HelpSteer3": data = hf_datasets.HelpSteer3Dataset() + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] + elif data_cls == "Tulu3Preference": + data = hf_datasets.Tulu3PreferenceDataset() + train_dataset = data.formatted_ds["train"] + val_dataset = None else: - raise ValueError(f"Unknown dataset class: {data_cls}") - print( - f" ✓ Training and validation datasets loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively." - ) + raise ValueError( + f"Unknown dataset class: {data_cls}. Supported datasets are: PreferenceDataset, HelpSteer3, and Tulu3Preference." + ) + + if train_dataset: + print( + f" ✓ Training dataset loaded with {len(data.formatted_ds['train'])} samples." + ) + if val_dataset: + print( + f" ✓ Validation dataset loaded with {len(data.formatted_ds['validation'])} samples." + ) - train_dataset = data.formatted_ds["train"] - val_dataset = data.formatted_ds["validation"] rm_task_spec = data.task_spec train_dataset = AllTaskProcessedDataset( @@ -131,13 +159,49 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): max_seq_length=data_config["max_input_seq_length"], ) - val_dataset = AllTaskProcessedDataset( - val_dataset, - tokenizer, - rm_task_spec, - rm_preprocessor, - max_seq_length=data_config["max_input_seq_length"], - ) + if data_cls == "PreferenceDataset": + val_dataset = {} + + assert "val_data_path" not in data_config, ( + "`val_data_path` cannot be provided for PreferenceDataset. You should use `val_data_paths` instead." + ) + assert "val_data_paths" in data_config, ( + "`val_data_paths` must be provided for PreferenceDataset" + ) + assert isinstance(data_config["val_data_paths"], dict), ( + f"Invalid type for val_data_paths: {type(data_config['val_data_paths'])}. val_data_paths must be a dictionary." + ) + val_data_paths = data_config["val_data_paths"] + + for val_dataset_name, val_dataset_path in val_data_paths.items(): + assert val_dataset_name not in val_dataset + val_data = hf_datasets.PreferenceDataset( + val_dataset_path, split="validation" + ) + print( + f" ✓ Validation dataset '{val_dataset_name}' loaded with {len(val_data.formatted_ds['validation'])} samples." + ) + val_dataset[val_dataset_name] = AllTaskProcessedDataset( + val_data.formatted_ds["validation"], + tokenizer, + val_data.task_spec, + rm_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + else: + val_dataset = ( + { + "default": AllTaskProcessedDataset( + val_dataset, + tokenizer, + rm_task_spec, + rm_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + } + if val_dataset + else {} + ) return train_dataset, val_dataset, rm_task_spec diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 32978a3e83..579099c530 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -16,7 +16,7 @@ from collections import defaultdict from functools import partial from pathlib import Path -from typing import NotRequired, Optional, TypedDict, cast +from typing import Optional, TypedDict, cast import numpy as np import torch @@ -28,7 +28,7 @@ ) from nemo_rl.algorithms.utils import set_seed from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, dpo_collate_fn +from nemo_rl.data.datasets import AllTaskProcessedDataset, preference_collate_fn from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import PolicyInterface @@ -43,7 +43,6 @@ class DPOSaveState(TypedDict): epoch: int # Track current epoch step: int # Track step within current epoch total_steps: int # Track total number of steps across all epochs - val_loss: NotRequired[float] # Optional field - may not be present during training consumed_samples: int @@ -86,6 +85,11 @@ class MasterConfig(TypedDict): checkpointing: CheckpointingConfig +class DPOValMetrics(TypedDict): + loss: float + accuracy: float + + # ======================================================= # Setup & Initialization # ======================================================= @@ -93,12 +97,12 @@ def setup( master_config: MasterConfig, tokenizer: AutoTokenizer, train_dataset: AllTaskProcessedDataset, - val_dataset: AllTaskProcessedDataset, + val_dataset: dict[str, AllTaskProcessedDataset], ) -> tuple[ Policy, RayVirtualCluster, StatefulDataLoader, - StatefulDataLoader, + dict[str, StatefulDataLoader], DPOLossFn, Logger, CheckpointManager, @@ -154,11 +158,12 @@ def setup( batch_size=policy_config["train_global_batch_size"], shuffle=data_config["shuffle"], collate_fn=partial( - dpo_collate_fn, + preference_collate_fn, tokenizer=tokenizer, make_sequence_length_divisible_by=policy_config[ "make_sequence_length_divisible_by" ], + add_loss_mask=True, ), drop_last=True, ) @@ -169,19 +174,23 @@ def setup( ) train_dataloader.load_state_dict(dataloader_state_dict) - val_dataloader = StatefulDataLoader( - val_dataset, - batch_size=dpo_config["val_global_batch_size"], - shuffle=False, - collate_fn=partial( - dpo_collate_fn, - tokenizer=tokenizer, - make_sequence_length_divisible_by=policy_config[ - "make_sequence_length_divisible_by" - ], - ), - drop_last=True, - ) + val_dataloader = { + k: StatefulDataLoader( + v, + batch_size=dpo_config["val_global_batch_size"], + shuffle=False, + collate_fn=partial( + preference_collate_fn, + tokenizer=tokenizer, + make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + add_loss_mask=True, + ), + drop_last=True, + ) + for k, v in val_dataset.items() + } # ========================== # Cluster @@ -266,7 +275,7 @@ def add_ref_logprobs_to_data(dataloader, policy, master_config, is_val=False): # ======================================================= def validate( policy: PolicyInterface, - val_dataloader: StatefulDataLoader, + val_dataloader: dict[str, StatefulDataLoader], tokenizer, loss_fn, step: int, @@ -274,8 +283,56 @@ def validate( val_batches: int, val_batch_size: int, val_mbs: int, + logger: Logger, ): - """Run validation on the validation dataset.""" + val_metrics, validation_timings = {}, {} + for val_dataset_name, v in val_dataloader.items(): + k_val_metrics, k_validation_timings = validate_one_dataset( + policy=policy, + val_dataloader=v, + loss_fn=loss_fn, + step=step, + master_config=master_config, + val_batches=val_batches, + val_batch_size=val_batch_size, + val_mbs=val_mbs, + dataset_name=val_dataset_name, + ) + prefix = f"validation-{val_dataset_name}" + + logger.log_metrics(k_val_metrics, step, prefix=prefix) + logger.log_metrics(k_validation_timings, step, prefix=f"timing/{prefix}") + + for metric_name in DPOValMetrics.__annotations__.keys(): + val_metrics[f"{prefix}_{metric_name}"] = k_val_metrics[metric_name] + validation_timings[prefix + "_total_validation_time"] = k_validation_timings[ + "total_validation_time" + ] + + if len(validation_timings) > 0: + total_validation_time = sum(validation_timings.values()) + logger.log_metrics( + {"total_validation_time": total_validation_time}, + step, + prefix="timing/validation", + ) + validation_timings["total_validation_time"] = total_validation_time + + return val_metrics, validation_timings + + +def validate_one_dataset( + policy: PolicyInterface, + val_dataloader: StatefulDataLoader, + loss_fn, + step: int, + master_config: MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + dataset_name: str, +): + """Run validation on one validation dataset.""" if val_dataloader is None: print(" ⚠️ No validation dataloader provided, skipping validation") return @@ -283,7 +340,7 @@ def validate( timer = Timer() with timer.time("total_validation_time"): - print(f"▶ Starting validation at step {step}...") + print(f"▶ Starting validation at step {step} for `{dataset_name}` set..") val_metrics = defaultdict(lambda: 0.0) num_valid_batches = 0 @@ -336,12 +393,12 @@ def validate( else: # Print summary of validation results - print("\n📊 Validation Results:") - print(f" • Validation loss: {float(val_metrics['loss']):.4f}") - print(f" • Validation accuracy: {float(val_metrics['accuracy']):.4f}") + print(f"\n📊 Validation Results for `{dataset_name}` set:") + for metric_name in DPOValMetrics.__annotations__.keys(): + print(f" • Validation {metric_name}: {val_metrics[metric_name]:.4f}") # Print timing information - print("\n ⏱️ Validation Timing:") + print(f"\n ⏱️ Validation Timing for `{dataset_name}` set:") validation_time = timing_metrics.get("total_validation_time", 0) print(f" • Total validation time: {validation_time:.2f}s") @@ -399,15 +456,13 @@ def dpo_train( val_batches=dpo_config["val_batches"], val_batch_size=dpo_config["val_global_batch_size"], val_mbs=dpo_config["val_micro_batch_size"], + logger=logger, ) if validation_result is not None: val_metrics, validation_timings = validation_result else: val_metrics, validation_timings = None, None - logger.log_metrics(val_metrics, total_steps, prefix="validation") - logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") - policy.prepare_for_training() while ( @@ -455,17 +510,12 @@ def dpo_train( val_batches=dpo_config["val_batches"], val_batch_size=dpo_config["val_global_batch_size"], val_mbs=dpo_config["val_micro_batch_size"], + logger=logger, ) if validation_result is not None: val_metrics, validation_timings = validation_result else: val_metrics, validation_timings = None, None - logger.log_metrics( - validation_timings, total_steps + 1, prefix="timing/validation" - ) - logger.log_metrics( - val_metrics, total_steps + 1, prefix="validation" - ) ## Checkpointing dpo_save_state["consumed_samples"] += master_config["policy"][ @@ -488,10 +538,22 @@ def dpo_train( dpo_save_state["step"] = (current_step + 1) % len(train_dataloader) dpo_save_state["total_steps"] = total_steps + 1 dpo_save_state["epoch"] = current_epoch + # Remove outdated validation metrics + for key in list(dpo_save_state): + if ( + key.startswith("val") + and any( + [ + key.endswith(f"_{metric_name}") + for metric_name in DPOValMetrics.__annotations__.keys() + if metric_name != "num_valid_samples" + ] + ) + and (val_metrics is None or key not in val_metrics) + ): + del dpo_save_state[key] if val_metrics is not None: - dpo_save_state["val_loss"] = val_metrics["loss"] - elif "val_loss" in dpo_save_state: - del dpo_save_state["val_loss"] + dpo_save_state.update(val_metrics) if master_config["checkpointing"]["metric_name"] is not None: if ( @@ -540,7 +602,8 @@ def dpo_train( timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\n📊 Training Results:") - print(f" • Loss: {float(metrics['loss']):.4f}") + for metric_name in DPOValMetrics.__annotations__.keys(): + print(f" • {metric_name}: {float(metrics[metric_name]):.4f}") if "total_flops" in train_results: total_tflops = ( train_results["total_flops"] diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index 1dafc3800d..b1aa9f01be 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -13,6 +13,8 @@ # limitations under the License. import os import warnings +from collections import defaultdict +from functools import partial from pathlib import Path from typing import Optional, TypedDict @@ -31,11 +33,6 @@ preference_collate_fn, ) from nemo_rl.data.interfaces import TaskDataSpec -from nemo_rl.data.llm_message_utils import ( - add_loss_mask_to_message_log, - batched_message_log_to_flat_message, -) -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import PolicyInterface @@ -50,7 +47,6 @@ class RMSaveState(TypedDict): epoch: int # Track current epoch step: int # Track step within current epoch total_steps: int # Track total number of steps across all epochs - val_loss: float consumed_samples: int @@ -84,7 +80,7 @@ class MasterConfig(TypedDict): class RMValMetrics(TypedDict): - val_loss: float + loss: float accuracy: float rewards_chosen_mean: float rewards_rejected_mean: float @@ -98,12 +94,12 @@ def setup( master_config: MasterConfig, tokenizer: AutoTokenizer, train_dataset: AllTaskProcessedDataset, - val_dataset: AllTaskProcessedDataset, + val_dataset: dict[str, AllTaskProcessedDataset], ) -> tuple[ Policy, RayVirtualCluster, StatefulDataLoader, - StatefulDataLoader, + dict[str, StatefulDataLoader], PreferenceLoss, MasterConfig, Logger, @@ -146,7 +142,14 @@ def setup( train_dataset, batch_size=policy_config["train_global_batch_size"], shuffle=data_config["shuffle"], - collate_fn=preference_collate_fn, + collate_fn=partial( + preference_collate_fn, + tokenizer=tokenizer, + make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + add_loss_mask=False, + ), drop_last=True, ) @@ -156,13 +159,23 @@ def setup( ) train_dataloader.load_state_dict(dataloader_state_dict) - val_dataloader = StatefulDataLoader( - val_dataset, - batch_size=rm_config["val_global_batch_size"], - shuffle=False, - collate_fn=preference_collate_fn, - drop_last=True, - ) + val_dataloader = { + k: StatefulDataLoader( + v, + batch_size=rm_config["val_global_batch_size"], + shuffle=False, + collate_fn=partial( + preference_collate_fn, + tokenizer=tokenizer, + make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + add_loss_mask=False, + ), + drop_last=True, + ) + for k, v in val_dataset.items() + } # ========================== # Cluster @@ -220,17 +233,65 @@ def setup( # ======================================================= def validate( policy: PolicyInterface, - val_dataloader: StatefulDataLoader, + val_dataloader: dict[str, StatefulDataLoader], tokenizer, loss_fn, step: int, master_config: MasterConfig, - rm_task_spec: TaskDataSpec, val_batches: int, val_batch_size: int, val_mbs: int, + logger: Logger, ): - """Run validation on the validation dataset.""" + val_metrics, validation_timings = {}, {} + for val_dataset_name, v in val_dataloader.items(): + k_val_metrics, k_validation_timings = validate_one_dataset( + policy=policy, + val_dataloader=v, + loss_fn=loss_fn, + step=step, + master_config=master_config, + val_batches=val_batches, + val_batch_size=val_batch_size, + val_mbs=val_mbs, + dataset_name=val_dataset_name, + ) + prefix = f"validation-{val_dataset_name}" + + logger.log_metrics(k_val_metrics, step, prefix=prefix) + logger.log_metrics(k_validation_timings, step, prefix=f"timing/{prefix}") + + for metric_name in RMValMetrics.__annotations__.keys(): + if metric_name != "num_valid_samples": + val_metrics[f"{prefix}_{metric_name}"] = k_val_metrics[metric_name] + validation_timings[prefix + "_total_validation_time"] = k_validation_timings[ + "total_validation_time" + ] + + if len(validation_timings) > 0: + total_validation_time = sum(validation_timings.values()) + logger.log_metrics( + {"total_validation_time": total_validation_time}, + step, + prefix="timing/validation", + ) + validation_timings["total_validation_time"] = total_validation_time + + return val_metrics, validation_timings + + +def validate_one_dataset( + policy: PolicyInterface, + val_dataloader: StatefulDataLoader, + loss_fn, + step: int, + master_config: MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + dataset_name: str, +): + """Run validation on one validation dataset.""" if val_dataloader is None: print(" ⚠️ No validation dataloader provided, skipping validation") return @@ -238,43 +299,17 @@ def validate( timer = Timer() with timer.time("total_validation_time"): - print(f"▶ Starting validation at step {step}...") + print(f"▶ Starting validation at step {step} for `{dataset_name}` set..") # Show a progress indicator for validation # val_total = len(val_dataloader) - list_of_val_metrics = [] - + dict_val_metrics = defaultdict(list) num_valid_batches = 0 - - policy.prepare_for_training() for batch_idx, val_batch in enumerate(val_dataloader): - ## add loss mask based on role to every message - add_loss_mask_to_message_log( - val_batch["message_log"], - roles_to_train_on=["assistant"], - ) - - cat_and_padded, input_lengths = batched_message_log_to_flat_message( - val_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - make_sequence_length_divisible_by=master_config["policy"][ - "make_sequence_length_divisible_by" - ], - ) - - val_data: BatchedDataDict = BatchedDataDict( - { - "input_ids": cat_and_padded["token_ids"], - "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": val_batch["loss_multiplier"], - } - ) - ## just run model fwd val_results = policy.train( - val_data, + val_batch, loss_fn, eval_mode=True, ## NOTE: we double the batch size here because each preference example corresponds to a pair of @@ -289,21 +324,10 @@ def validate( " This is likely because there were no valid samples." ) else: - list_of_val_metrics.append( - RMValMetrics( - val_loss=sum(val_results["all_mb_metrics"]["loss"]), - accuracy=sum(val_results["all_mb_metrics"]["accuracy"]), - rewards_chosen_mean=sum( - val_results["all_mb_metrics"]["rewards_chosen_mean"] - ), - rewards_rejected_mean=sum( - val_results["all_mb_metrics"]["rewards_rejected_mean"] - ), - num_valid_samples=sum( - val_results["all_mb_metrics"]["num_valid_samples"] - ), - ) - ) + for metric_name in RMValMetrics.__annotations__.keys(): + dict_val_metrics[metric_name] += [ + sum(val_results["all_mb_metrics"][metric_name]) + ] num_valid_batches += 1 @@ -311,39 +335,23 @@ def validate( break if num_valid_batches > 0: - sum_num_valid_samples = sum( - [m["num_valid_samples"] for m in list_of_val_metrics] - ) + sum_num_valid_samples = sum(dict_val_metrics["num_valid_samples"]) val_metrics = RMValMetrics( - val_loss=sum( - [ - m["val_loss"] * m["num_valid_samples"] - for m in list_of_val_metrics - ] - ) - / sum_num_valid_samples, - accuracy=sum( - [ - m["accuracy"] * m["num_valid_samples"] - for m in list_of_val_metrics - ] - ) - / sum_num_valid_samples, - rewards_chosen_mean=sum( - [ - m["rewards_chosen_mean"] * m["num_valid_samples"] - for m in list_of_val_metrics - ] - ) - / sum_num_valid_samples, - rewards_rejected_mean=sum( - [ - m["rewards_rejected_mean"] * m["num_valid_samples"] - for m in list_of_val_metrics - ] - ) - / sum_num_valid_samples, num_valid_samples=sum_num_valid_samples, + **{ + metric_name: sum( + [ + value * weight + for value, weight in zip( + dict_val_metrics[metric_name], + dict_val_metrics["num_valid_samples"], + ) + ] + ) + / sum_num_valid_samples + for metric_name in RMValMetrics.__annotations__.keys() + if metric_name != "num_valid_samples" + }, ) else: warnings.warn( @@ -351,11 +359,10 @@ def validate( " This is likely because there were no valid samples in the validation set." ) val_metrics = RMValMetrics( - val_loss=0.0, - accuracy=0.0, - rewards_chosen_mean=0.0, - rewards_rejected_mean=0.0, - num_valid_samples=0.0, + **{ + metric_name: 0.0 + for metric_name in RMValMetrics.__annotations__.keys() + } ) # Calculate validation metrics @@ -367,21 +374,17 @@ def validate( if num_valid_batches > 0: # Print summary of validation results - print("\n📊 Validation Results:") - print(f" • Validation loss: {val_metrics['val_loss']:.4f}") - print(f" • Validation accuracy: {val_metrics['accuracy']:.4f}") - print( - f" • Validation rewards chosen mean: {val_metrics['rewards_chosen_mean']:.4f}" - ) - print( - f" • Validation rewards rejected mean: {val_metrics['rewards_rejected_mean']:.4f}" - ) - print( - f" • Validation num valid samples: {val_metrics['num_valid_samples']:.0f}" - ) + print(f"\n📊 Validation Results for `{dataset_name}` set:") + for metric_name in RMValMetrics.__annotations__.keys(): + if metric_name != "num_valid_samples": + print(f" • Validation {metric_name}: {val_metrics[metric_name]:.4f}") + else: + print( + f" • Validation num valid samples: {val_metrics['num_valid_samples']:.0f}" + ) # Print timing information - print("\n ⏱️ Validation Timing:") + print(f"\n ⏱️ Validation Timing for `{dataset_name}` set:") validation_time = timing_metrics.get("total_validation_time", 0) print(f" • Total validation time: {validation_time:.2f}s") @@ -432,15 +435,12 @@ def rm_train( loss_fn, step=0, master_config=master_config, - rm_task_spec=rm_task_spec, val_batches=rm_config["val_batches"], val_batch_size=rm_config["val_global_batch_size"], val_mbs=rm_config["val_micro_batch_size"], + logger=logger, ) - logger.log_metrics(val_metrics, total_steps, prefix="validation") - logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") - policy.prepare_for_training() while current_epoch < max_num_epochs and ( @@ -458,35 +458,10 @@ def rm_train( with timer.time("total_step_time"): # Prepare batch and generate responses - print("▶ Preparing batch...") - with timer.time("data_processing"): - ## add loss mask based on role to every message - add_loss_mask_to_message_log( - batch["message_log"], - roles_to_train_on=["assistant"], - ) - - cat_and_padded, input_lengths = batched_message_log_to_flat_message( - batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - make_sequence_length_divisible_by=master_config["policy"][ - "make_sequence_length_divisible_by" - ], - ) - - train_data: BatchedDataDict = BatchedDataDict( - { - "input_ids": cat_and_padded["token_ids"], - "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": batch["loss_multiplier"], - } - ) - print("▶ Taking a training step...") train_results = policy.train( - train_data, + batch, loss_fn, eval_mode=False, ## NOTE: we double the batch size here because each preference example corresponds to a pair of @@ -512,16 +487,10 @@ def rm_train( loss_fn, step=total_steps + 1, master_config=master_config, - rm_task_spec=rm_task_spec, val_batches=rm_config["val_batches"], val_batch_size=rm_config["val_global_batch_size"], val_mbs=rm_config["val_micro_batch_size"], - ) - logger.log_metrics( - validation_timings, total_steps + 1, prefix="timing/validation" - ) - logger.log_metrics( - val_metrics, total_steps + 1, prefix="validation" + logger=logger, ) ## Checkpointing @@ -537,10 +506,22 @@ def rm_train( rm_save_state["step"] = (current_step + 1) % len(train_dataloader) rm_save_state["total_steps"] = total_steps + 1 rm_save_state["epoch"] = current_epoch + # Remove outdated validation metrics + for key in list(rm_save_state): + if ( + key.startswith("val") + and any( + [ + key.endswith(f"_{metric_name}") + for metric_name in RMValMetrics.__annotations__.keys() + if metric_name != "num_valid_samples" + ] + ) + and (val_metrics is None or key not in val_metrics) + ): + del rm_save_state[key] if val_metrics is not None: - rm_save_state["val_loss"] = val_metrics["val_loss"] - elif "val_loss" in rm_save_state: - del rm_save_state["val_loss"] + rm_save_state.update(val_metrics) if master_config["checkpointing"]["metric_name"] is not None: if ( @@ -590,15 +571,11 @@ def rm_train( timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\n📊 Training Results:") - print(f" • Loss: {float(metrics['loss']):.4f}") - print(f" • Accuracy: {float(metrics['accuracy']):.4f}") - print( - f" • Rewards chosen mean: {float(metrics['rewards_chosen_mean']):.4f}" - ) - print( - f" • Rewards rejected mean: {float(metrics['rewards_rejected_mean']):.4f}" - ) - print(f" • Num valid samples: {float(metrics['num_valid_samples']):.0f}") + for metric_name in RMValMetrics.__annotations__.keys(): + if metric_name != "num_valid_samples": + print(f" • {metric_name}: {float(metrics[metric_name]):.4f}") + else: + print(f" • num valid samples: {float(metrics[metric_name]):.0f}") print("\n⏱️ Timing:") # Display total time first, separately diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index ee0600bf47..e15526e736 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -31,6 +31,8 @@ class DataConfig(TypedDict): shuffle: NotRequired[bool] seed: NotRequired[int] download_dir: NotRequired[str] + train_data_path: NotRequired[str] + val_data_paths: NotRequired[dict[str, str]] class MathDataConfig(DataConfig): diff --git a/nemo_rl/data/datasets.py b/nemo_rl/data/datasets.py index d406b1309b..60dd0e091c 100644 --- a/nemo_rl/data/datasets.py +++ b/nemo_rl/data/datasets.py @@ -231,6 +231,9 @@ def eval_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]: def preference_collate_fn( data_batch: list[DPODatumSpec], + tokenizer: TokenizerType, + make_sequence_length_divisible_by: int, + add_loss_mask: bool, ) -> BatchedDataDict[Any]: """Collate function for preference data training. @@ -240,9 +243,11 @@ def preference_collate_fn( Args: data_batch: List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields. - + tokenizer: Tokenizer for text processing + make_sequence_length_divisible_by: Make the sequence length divisible by this value + add_loss_mask: Whether to add a token_mask to the returned data Returns: - BatchedDataDict with message_log, length, loss_multiplier, task_name, and idx fields. + BatchedDataDict with input_ids, input_lengths, token_mask (optional), and sample_mask fields. """ message_log = [] length = [] @@ -272,31 +277,11 @@ def preference_collate_fn( batch_max_length=batch_max_length, ) - return batch - - -def dpo_collate_fn( - data_batch: list[DPODatumSpec], - tokenizer: TokenizerType, - make_sequence_length_divisible_by: int, -) -> BatchedDataDict[Any]: - """Collate function for DPO training. - - Args: - data_batch: List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields. - tokenizer: Tokenizer for text processing - make_sequence_length_divisible_by: Make the sequence length divisible by this value - - Returns: - BatchedDataDict with input_ids, input_lengths, token_mask, and sample_mask fields. - """ - batch = preference_collate_fn(data_batch) - - ## add loss mask based on role to every message - add_loss_mask_to_message_log( - batch["message_log"], - only_unmask_final=True, - ) + if add_loss_mask: + add_loss_mask_to_message_log( + batch["message_log"], + only_unmask_final=True, + ) cat_and_padded, input_lengths = batched_message_log_to_flat_message( batch["message_log"], @@ -304,16 +289,17 @@ def dpo_collate_fn( make_sequence_length_divisible_by=make_sequence_length_divisible_by, ) - train_data: BatchedDataDict[Any] = BatchedDataDict( + data: BatchedDataDict[Any] = BatchedDataDict( { "input_ids": cat_and_padded["token_ids"], "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], "sample_mask": batch["loss_multiplier"], } ) + if add_loss_mask: + data["token_mask"] = cat_and_padded["token_loss_mask"] - return train_data + return data def assert_no_double_bos(token_ids: torch.Tensor, tokenizer: TokenizerType) -> None: diff --git a/nemo_rl/data/hf_datasets/__init__.py b/nemo_rl/data/hf_datasets/__init__.py index 76ad13e680..3c3d7c91fa 100644 --- a/nemo_rl/data/hf_datasets/__init__.py +++ b/nemo_rl/data/hf_datasets/__init__.py @@ -19,6 +19,7 @@ from nemo_rl.data.hf_datasets.oai_format_dataset import OpenAIFormatDataset from nemo_rl.data.hf_datasets.oasst import OasstDataset from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset +from nemo_rl.data.hf_datasets.preference_dataset import PreferenceDataset from nemo_rl.data.hf_datasets.prompt_response_dataset import ( PromptResponseDataset, ) @@ -31,6 +32,7 @@ "OasstDataset", "OpenAIFormatDataset", "OpenMathInstruct2Dataset", + "PreferenceDataset", "PromptResponseDataset", "SquadDataset", "Tulu3PreferenceDataset", diff --git a/nemo_rl/data/hf_datasets/dpo.py b/nemo_rl/data/hf_datasets/dpo.py index 03d5c7e872..d0c96a7e21 100644 --- a/nemo_rl/data/hf_datasets/dpo.py +++ b/nemo_rl/data/hf_datasets/dpo.py @@ -11,14 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from typing import Any + from datasets import load_dataset from nemo_rl.data.interfaces import TaskDataSpec +def to_preference_data_format(data: dict[str, Any]) -> dict[str, list[dict[str, Any]]]: + return { + "context": data["prompt"] + if isinstance(data["prompt"], list) + else [{"role": "user", "content": data["prompt"]}], + "completions": [ + { + "rank": 0, + "completion": [ + {"role": "assistant", "content": data["chosen_response"]} + ], + }, + { + "rank": 1, + "completion": [ + {"role": "assistant", "content": data["rejected_response"]} + ], + }, + ], + } + + class DPODataset: """Dataset class for Direct Preference Optimization (DPO) training. + This class is deprecated and will be removed in a future version. Use PreferenceDataset instead. + This class handles loading of preference data for DPO training. The input JSON files should contain examples with the following structure: { @@ -34,9 +61,19 @@ class DPODataset: """ def __init__(self, train_data_path: str, val_data_path: str): + warnings.warn( + "DPODataset is deprecated and will be removed in a future version. Use PreferenceDataset instead (see function `to_preference_data_format()` on how to convert your data to this new format).", + category=DeprecationWarning, + stacklevel=2, + ) + self.formatted_ds = { - "train": load_dataset("json", data_files=train_data_path, split="train"), - "validation": load_dataset("json", data_files=val_data_path, split="train"), + "train": load_dataset( + "json", data_files=train_data_path, split="train" + ).map(to_preference_data_format), + "validation": load_dataset( + "json", data_files=val_data_path, split="train" + ).map(to_preference_data_format), } self.task_spec = TaskDataSpec( diff --git a/nemo_rl/data/hf_datasets/helpsteer3.py b/nemo_rl/data/hf_datasets/helpsteer3.py index 7d694c4c06..e80fbff302 100644 --- a/nemo_rl/data/hf_datasets/helpsteer3.py +++ b/nemo_rl/data/hf_datasets/helpsteer3.py @@ -19,7 +19,11 @@ from nemo_rl.data.interfaces import TaskDataSpec -def format_helpsteer3(data: dict[str, Any]) -> dict[str, str | dict[str, str]]: +def to_preference_data_format( + data: dict[str, Any], +) -> dict[ + str, list[dict[str, int | list[dict[str, str | Any]]]] | list[dict[str, str]] +]: response_1 = data["response1"] response_2 = data["response2"] overall_preference = data["overall_preference"] @@ -40,9 +44,13 @@ def format_helpsteer3(data: dict[str, Any]) -> dict[str, str | dict[str, str]]: rejected = response_1 return { - "prompt": data["context"], - "chosen_response": chosen, - "rejected_response": rejected, + "context": [{"role": "user", "content": data["context"]}] + if isinstance(data["context"], str) + else data["context"], + "completions": [ + {"rank": 0, "completion": [{"role": "assistant", "content": chosen}]}, + {"rank": 1, "completion": [{"role": "assistant", "content": rejected}]}, + ], } @@ -51,7 +59,7 @@ class HelpSteer3Dataset: def __init__(self) -> None: ds = load_dataset("nvidia/HelpSteer3", "preference") - self.formatted_ds = ds.map(format_helpsteer3) + self.formatted_ds = ds.map(to_preference_data_format) self.task_spec = TaskDataSpec( task_name="HelpSteer3", diff --git a/nemo_rl/data/hf_datasets/preference_dataset.py b/nemo_rl/data/hf_datasets/preference_dataset.py new file mode 100644 index 0000000000..dea4d7213d --- /dev/null +++ b/nemo_rl/data/hf_datasets/preference_dataset.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datasets import DatasetDict, load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +class PreferenceDataset: + """Preference dataset. + + This class handles loading of custom preference data. + The input JSONL file should contain valid JSON objects formatted like this: + { + "context": list of dicts, # The prompt message (including previous turns, if any) + "completions": list of dicts, # The list of completions + { + "rank": int, # The rank of the completion (lower rank is preferred) + "completion": list of dicts, # The completion message(s) + } + } + """ + + def __init__(self, dataset_path: str, split: str) -> None: + # Specifying split="train" returns Dataset instead of DatasetDict({"train": Dataset}) + self.formatted_ds = DatasetDict( + {split: load_dataset("json", data_files=dataset_path, split="train")} + ) + + self.task_spec = TaskDataSpec( + task_name="PreferenceDataset", + ) diff --git a/nemo_rl/data/hf_datasets/tulu3.py b/nemo_rl/data/hf_datasets/tulu3.py index ab3fa62623..266daea186 100644 --- a/nemo_rl/data/hf_datasets/tulu3.py +++ b/nemo_rl/data/hf_datasets/tulu3.py @@ -20,7 +20,11 @@ from nemo_rl.data.interfaces import TaskDataSpec -def format_tulu3_preference(data: dict[str, Any]) -> dict[str, str | dict[str, str]]: +def to_preference_data_format( + data: dict[str, Any], +) -> dict[ + str, list[dict[str, int | list[dict[str, str | Any]]]] | list[dict[str, str]] +]: chosen_conversation = data["chosen"] rejected_conversation = data["rejected"] @@ -46,9 +50,17 @@ def format_tulu3_preference(data: dict[str, Any]) -> dict[str, str | dict[str, s rejected_response = rejected_conversation[-1]["content"] return { - "prompt": context, - "chosen_response": chosen_response, - "rejected_response": rejected_response, + "context": context, + "completions": [ + { + "rank": 0, + "completion": [{"role": "assistant", "content": chosen_response}], + }, + { + "rank": 1, + "completion": [{"role": "assistant", "content": rejected_response}], + }, + ], } @@ -60,7 +72,7 @@ def __init__(self) -> None: path="allenai/llama-3.1-tulu-3-8b-preference-mixture", trust_remote_code=True, ) - self.formatted_ds = ds.map(format_tulu3_preference) + self.formatted_ds = ds.map(to_preference_data_format) self.task_spec = TaskDataSpec( task_name="Tulu3Preference", diff --git a/pyrefly.toml b/pyrefly.toml index e9717a1ed0..f4b3c426fb 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -60,6 +60,7 @@ project-includes = [ "nemo_rl/data/hf_datasets/oai_format_dataset.py", "nemo_rl/data/hf_datasets/oasst.py", "nemo_rl/data/hf_datasets/openmathinstruct2.py", + "nemo_rl/data/hf_datasets/preference_dataset.py", "nemo_rl/data/hf_datasets/prompt_response_dataset.py", "nemo_rl/data/hf_datasets/squad.py", "nemo_rl/data/hf_datasets/tulu3.py", diff --git a/tests/unit/data/hf_datasets/test_dpo_dataset.py b/tests/unit/data/hf_datasets/test_dpo_dataset.py index ed13df2c99..85261ff958 100644 --- a/tests/unit/data/hf_datasets/test_dpo_dataset.py +++ b/tests/unit/data/hf_datasets/test_dpo_dataset.py @@ -94,11 +94,16 @@ def test_dpo_dataset_data_format(mock_dpo_data): # Verify data format train_sample = dataset.formatted_ds["train"][0] - assert "prompt" in train_sample - assert "chosen_response" in train_sample - assert "rejected_response" in train_sample + assert "context" in train_sample + assert "completions" in train_sample # Verify data content - assert train_sample["prompt"] == "What is 2+2?" - assert train_sample["chosen_response"] == "The answer is 4." - assert train_sample["rejected_response"] == "I don't know." + print(train_sample["completions"]) + assert train_sample["context"] == [{"content": "What is 2+2?", "role": "user"}] + assert train_sample["completions"] == [ + { + "completion": [{"content": "The answer is 4.", "role": "assistant"}], + "rank": 0, + }, + {"completion": [{"content": "I don't know.", "role": "assistant"}], "rank": 1}, + ] diff --git a/tests/unit/data/hf_datasets/test_helpsteer.py b/tests/unit/data/hf_datasets/test_helpsteer.py index 036ba75669..4015a83b72 100644 --- a/tests/unit/data/hf_datasets/test_helpsteer.py +++ b/tests/unit/data/hf_datasets/test_helpsteer.py @@ -17,7 +17,7 @@ from nemo_rl.data.hf_datasets.helpsteer3 import ( HelpSteer3Dataset, - format_helpsteer3, + to_preference_data_format, ) @@ -31,8 +31,8 @@ def helpsteer3_dataset(): yield -def test_format_helpsteer3(): - """Test the format_helpsteer3 function with different preference values.""" +def test_to_preference_data_format(): + """Test the `to_preference_data_format()` function with different preference values.""" # Test case 1: response1 is preferred (overall_preference < 0) data1 = { "context": "What is 2+2?", @@ -40,10 +40,15 @@ def test_format_helpsteer3(): "response2": "I don't know.", "overall_preference": -1, } - result1 = format_helpsteer3(data1) - assert result1["prompt"] == "What is 2+2?" - assert result1["chosen_response"] == "The answer is 4." - assert result1["rejected_response"] == "I don't know." + result1 = to_preference_data_format(data1) + assert result1["context"] == [{"content": "What is 2+2?", "role": "user"}] + assert result1["completions"] == [ + { + "rank": 0, + "completion": [{"role": "assistant", "content": "The answer is 4."}], + }, + {"rank": 1, "completion": [{"role": "assistant", "content": "I don't know."}]}, + ] # Test case 2: response2 is preferred (overall_preference > 0) data2 = { @@ -52,10 +57,24 @@ def test_format_helpsteer3(): "response2": "The capital of France is Paris.", "overall_preference": 1, } - result2 = format_helpsteer3(data2) - assert result2["prompt"] == "What is the capital of France?" - assert result2["chosen_response"] == "The capital of France is Paris." - assert result2["rejected_response"] == "The capital of France is London." + result2 = to_preference_data_format(data2) + assert result2["context"] == [ + {"content": "What is the capital of France?", "role": "user"} + ] + assert result2["completions"] == [ + { + "rank": 0, + "completion": [ + {"role": "assistant", "content": "The capital of France is Paris."} + ], + }, + { + "rank": 1, + "completion": [ + {"role": "assistant", "content": "The capital of France is London."} + ], + }, + ] # Test case 3: no preference (overall_preference = 0) data3 = { @@ -64,12 +83,44 @@ def test_format_helpsteer3(): "response2": "The weather is sunny.", "overall_preference": 0, } - result3 = format_helpsteer3(data3) - assert result3["prompt"] == "What is the weather like?" + result3 = to_preference_data_format(data3) + assert result3["context"] == [ + {"content": "What is the weather like?", "role": "user"} + ] # When preference is 0, neither response is preferred, so # response 1 is used for both chosen and rejected - assert result3["chosen_response"] == "It's sunny today." - assert result3["rejected_response"] == "It's sunny today." + assert result3["completions"] == [ + { + "rank": 0, + "completion": [{"role": "assistant", "content": "It's sunny today."}], + }, + { + "rank": 1, + "completion": [{"role": "assistant", "content": "It's sunny today."}], + }, + ] + + # Test case 4: context is a list of dicts + data1 = { + "context": [ + {"role": "user", "content": "Can I ask you a question?"}, + {"role": "assistant", "content": "Sure, what do you want to know?"}, + {"role": "user", "content": "What is 2+2?"}, + ], + "response1": "4.", + "response2": "I don't know.", + "overall_preference": -1, + } + result1 = to_preference_data_format(data1) + assert result1["context"] == [ + {"role": "user", "content": "Can I ask you a question?"}, + {"role": "assistant", "content": "Sure, what do you want to know?"}, + {"role": "user", "content": "What is 2+2?"}, + ] + assert result1["completions"] == [ + {"rank": 0, "completion": [{"role": "assistant", "content": "4."}]}, + {"rank": 1, "completion": [{"role": "assistant", "content": "I don't know."}]}, + ] def test_helpsteer3_dataset_initialization(helpsteer3_dataset): @@ -96,6 +147,5 @@ def test_helpsteer3_dataset_data_format(helpsteer3_dataset): # Verify data format sample = dataset.formatted_ds["train"][0] - assert "prompt" in sample - assert "chosen_response" in sample - assert "rejected_response" in sample + assert "context" in sample + assert "completions" in sample diff --git a/tests/unit/data/hf_datasets/test_preference_dataset.py b/tests/unit/data/hf_datasets/test_preference_dataset.py new file mode 100644 index 0000000000..955a91809f --- /dev/null +++ b/tests/unit/data/hf_datasets/test_preference_dataset.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile + +import pytest + +from nemo_rl.data.hf_datasets.preference_dataset import PreferenceDataset + + +@pytest.fixture +def mock_preference_data(): + """Create temporary preference dataset files with sample data.""" + preference_data = [ + { + "context": [{"role": "user", "content": "What is 2+2?"}], + "completions": [ + { + "rank": 1, + "completion": [ + {"role": "assistant", "content": "The answer is 4."} + ], + }, + { + "rank": 2, + "completion": [{"role": "assistant", "content": "I don't know."}], + }, + ], + }, + { + "context": [{"role": "user", "content": "What is the capital of France?"}], + "completions": [ + { + "rank": 1, + "completion": [ + { + "role": "assistant", + "content": "The capital of France is Paris.", + } + ], + }, + { + "rank": 2, + "completion": [ + { + "role": "assistant", + "content": "The capital of France is London.", + } + ], + }, + ], + }, + ] + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as preference_file: + json.dump(preference_data, preference_file) + preference_path = preference_file.name + + try: + yield preference_path + finally: + # Cleanup + os.unlink(preference_path) + + +def test_preference_dataset_initialization(mock_preference_data): + """Test that PreferenceDataset initializes correctly with valid data files.""" + preference_path = mock_preference_data + + dataset = PreferenceDataset(dataset_path=preference_path, split="train") + + # Verify dataset initialization + assert dataset.task_spec.task_name == "PreferenceDataset" + + # Verify formatted_ds structure + assert "train" in dataset.formatted_ds + assert len(dataset.formatted_ds["train"]) == 2 + + +def test_preference_dataset_data_format(mock_preference_data): + """Test that PreferenceDataset correctly loads and formats the data.""" + preference_path = mock_preference_data + dataset = PreferenceDataset(dataset_path=preference_path, split="train") + + # Verify data format + sample = dataset.formatted_ds["train"][0] + assert "context" in sample + assert "completions" in sample + + # Verify context structure + assert isinstance(sample["context"], list) + assert len(sample["context"]) == 1 + assert "role" in sample["context"][0] + assert "content" in sample["context"][0] + + # Verify completions structure + assert isinstance(sample["completions"], list) + assert len(sample["completions"]) == 2 + + for completion in sample["completions"]: + assert "rank" in completion + assert "completion" in completion + assert isinstance(completion["rank"], int) + assert isinstance(completion["completion"], list) diff --git a/tests/unit/data/hf_datasets/test_tulu3.py b/tests/unit/data/hf_datasets/test_tulu3.py new file mode 100644 index 0000000000..d5ccf2d254 --- /dev/null +++ b/tests/unit/data/hf_datasets/test_tulu3.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from nemo_rl.data.hf_datasets.tulu3 import ( + Tulu3PreferenceDataset, + to_preference_data_format, +) + + +@pytest.fixture(scope="module") +def tulu3_dataset(): + try: + dataset = Tulu3PreferenceDataset() + yield dataset + except Exception as e: + print(f"Error during loading Tulu3PreferenceDataset: {e}") + yield + + +def test_to_preference_data_format(): + """Test the `to_preference_data_format()` function with different preference values.""" + data = { + "prompt": "What is 2+2?", + "chosen": [ + {"content": "What is 2+2?", "role": "user"}, + {"role": "assistant", "content": "The answer is 4."}, + ], + "rejected": [ + {"content": "What is 2+2?", "role": "user"}, + {"role": "assistant", "content": "I don't know."}, + ], + } + result = to_preference_data_format(data) + assert result["context"] == [{"content": "What is 2+2?", "role": "user"}] + assert result["completions"] == [ + { + "rank": 0, + "completion": [{"role": "assistant", "content": "The answer is 4."}], + }, + {"rank": 1, "completion": [{"role": "assistant", "content": "I don't know."}]}, + ] + + +def test_tulu3_dataset_initialization(tulu3_dataset): + """Test that Tulu3PreferenceDataset initializes correctly.""" + + dataset = tulu3_dataset + if dataset is None: + pytest.skip("dataset download is flaky") + + # Verify dataset initialization + assert dataset.task_spec.task_name == "Tulu3Preference" + + +def test_tulu3_dataset_data_format(tulu3_dataset): + """Test that Tulu3PreferenceDataset correctly formats the data.""" + + dataset = tulu3_dataset + if dataset is None: + pytest.skip("dataset download is flaky") + + assert isinstance(dataset.formatted_ds, dict) + assert "train" in dataset.formatted_ds + + # Verify data format + sample = dataset.formatted_ds["train"][0] + assert "prompt" in sample + assert "chosen" in sample + assert "rejected" in sample diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py index d879b09a85..6cdd8203b5 100755 --- a/tests/unit/data/test_datasets.py +++ b/tests/unit/data/test_datasets.py @@ -16,13 +16,13 @@ import torch -from nemo_rl.data.datasets import dpo_collate_fn +from nemo_rl.data.datasets import preference_collate_fn from nemo_rl.data.interfaces import DatumSpec from nemo_rl.distributed.batched_data_dict import BatchedDataDict -def test_dpo_collate_fn(): - """Test that dpo_collate_fn correctly processes DPO training data.""" +def test_preference_collate_fn(): + """Test that preference_collate_fn correctly processes preference data.""" # Create mock tokenizer mock_tokenizer = MagicMock() mock_tokenizer.pad_token_id = 0 @@ -93,9 +93,12 @@ def test_dpo_collate_fn(): ), ] - # Call dpo_collate_fn - train_data = dpo_collate_fn( - data_batch, mock_tokenizer, make_sequence_length_divisible_by=16 + # Call preference_collate_fn + train_data = preference_collate_fn( + data_batch, + mock_tokenizer, + make_sequence_length_divisible_by=16, + add_loss_mask=True, ) # Verify the output structure