diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 7a80e5dcd7..54d81f46ee 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -170,6 +170,7 @@ jobs: if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then uv run --no-sync bash ./tests/functional/sft.sh uv run --no-sync bash ./tests/functional/grpo.sh + uv run --no-sync bash ./tests/functional/dpo.sh else echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }} fi diff --git a/README.md b/README.md index 14391c5883..eef9fc1b39 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,15 @@ - [Features](#features) - [Prerequisuites](#prerequisuites) - [Quick start](#quick-start) - - [SFT](#sft) + - [GRPO](#grpo) - [Single Node](#single-node) - [Multi-node](#multi-node) - - [GRPO](#grpo) + - [SFT](#sft) - [Single Node](#single-node-1) - [Multi-node](#multi-node-1) + - [DPO](#dpo) + - [Single Node](#single-node-2) + - [Multi-node](#multi-node-2) - [Cluster Start](#cluster-start) **Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters. @@ -33,10 +36,10 @@ What you can expect: - ✅ **Environment Support** - Support for multi-environment training. 
- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning)
- ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state)
+- ✅ **DPO Algorithm** - Direct Preference Optimization for alignment
- 🔜 **Larger Model Support** - Native PyTorch support for models up to 70B parameters
- 🔜 **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training
- 🔜 **Environment Isolation** - Dependency isolation between components
-- 🔜 **DPO Algorithm** - Direct Preference Optimization for alignment

## Prerequisuites

@@ -59,6 +62,61 @@ pip install uv

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models.

+### GRPO
+
+We have a reference GRPO experiment config set up for math benchmarks, trained on the [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset.
+
+#### Single Node
+
+To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`:
+
+```sh
+# Run the GRPO math example using a 1B parameter model
+uv run python examples/run_grpo_math.py
+```
+
+By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 GPUs:
+
+```sh
+# Run the GRPO math example using a 1B parameter model on 8 GPUs
+uv run python examples/run_grpo_math.py \
+    cluster.gpus_per_node=8
+```
+
+You can override any of the parameters listed in the yaml configuration file.
For example,
+
+```sh
+uv run python examples/run_grpo_math.py \
+    policy.model_name="Qwen/Qwen2-1.5B" \
+    checkpointing.checkpoint_dir="results/qwen1_5b_math" \
+    logger.wandb_enabled=True \
+    logger.wandb.name="grpo-qwen1_5b_math" \
+    logger.num_val_samples_to_print=10
+```
+
+#### Multi-node
+
+```sh
+# Run from the root of NeMo-Reinforcer repo
+NUM_ACTOR_NODES=2
+# Add a timestamp to make each job name unique
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+
+# grpo_math_8b uses the Llama-3.1-8B-Instruct model
+COMMAND="uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \
+UV_CACHE_DIR=YOUR_UV_CACHE_DIR \
+CONTAINER=YOUR_CONTAINER \
+MOUNTS="$PWD:$PWD" \
+sbatch \
+    --nodes=${NUM_ACTOR_NODES} \
+    --account=YOUR_ACCOUNT \
+    --job-name=YOUR_JOBNAME \
+    --partition=YOUR_PARTITION \
+    --time=4:0:0 \
+    --gres=gpu:8 \
+    ray.sub
+```
+
### SFT

We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/).
@@ -87,15 +145,12 @@ Refer to `examples/configs/sft.yaml` for a full list of parameters that can be o

#### Multi-node

-For distributed training across multiple nodes:
-
```sh
# Run from the root of NeMo-Reinforcer repo
NUM_ACTOR_NODES=2
# Add a timestamp to make each job name unique
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

-# SFT experiment uses Llama-3.1-8B model
COMMAND="uv run ./examples/run_sft.py --config examples/configs/sft.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/sft_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='sft-llama8b'" \
CONTAINER=YOUR_CONTAINER \
MOUNTS="$PWD:$PWD" \
@@ -109,48 +164,55 @@ sbatch \
    ray.sub
```

-### GRPO
+### DPO

-We have a reference GRPO experiment config set up trained for math benchmarks using the [OpenInstructMath2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset.
+We provide a sample DPO experiment that uses the [HelpSteer3 dataset](https://huggingface.co/datasets/nvidia/HelpSteer3) for preference-based training. #### Single Node -To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`: +The default DPO experiment is configured to run on a single GPU. To launch the experiment: ```sh -# Run the GRPO math example using a 1B parameter model -uv run python examples/run_grpo_math.py +uv run python examples/run_dpo.py ``` -By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 gpus, +This trains `Llama3.2-1B-Instruct` on one GPU. + +If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama3.1 Instruct model: ```sh -# Run the GRPO math example using a 1B parameter model using 8 GPUs -uv run python examples/run_grpo_math.py \ +uv run python examples/run_dpo.py \ + policy.model_name="meta-llama/Llama-3.1-8B-Instruct" \ + policy.train_global_batch_size=256 \ cluster.gpus_per_node=8 ``` -You can override any of the parameters listed in the yaml configuration file. For example, +Any of the DPO parameters can be customized from the command line. For example: ```sh -uv run python examples/run_grpo_math.py \ - policy.model_name="Qwen/Qwen2-1.5B" \ - checkpointing.checkpoint_dir="results/qwen1_5b_math" \ +uv run python examples/run_dpo.py \ + dpo.sft_loss_weight=0.1 \ + dpo.preference_average_log_probs=True \ + checkpointing.checkpoint_dir="results/llama_dpo_sft" \ logger.wandb_enabled=True \ - logger.wandb.name="grpo-qwen1_5b_math" \ - logger.num_val_samples_to_print=10 \ + logger.wandb.name="llama-dpo-sft" ``` +Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md). 
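For intuition about what these knobs control, the preference term at the heart of DPO can be sketched in a few lines of plain Python. This is an illustrative sketch of the standard DPO objective, not the repo's implementation (the function names and example log-probabilities are made up, and it assumes `dpo.reference_policy_kl_penalty` plays the role of the usual DPO β):

```python
import math

def _neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x))."""
    return math.log1p(math.exp(-x)) if x >= 0 else -x + math.log1p(math.exp(x))

def dpo_preference_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.05,  # cf. dpo.reference_policy_kl_penalty in dpo.yaml
) -> float:
    """Standard DPO loss: -log sigmoid(beta * (chosen margin - rejected margin))."""
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    return _neg_log_sigmoid(beta * (chosen_margin - rejected_margin))

# A policy that favors the chosen response more than the reference does gets a lower loss
loose = dpo_preference_loss(-12.0, -11.0, -11.5, -11.5)  # leans toward the rejected response
tight = dpo_preference_loss(-10.0, -13.0, -11.5, -11.5)  # leans toward the chosen response
assert tight < loose
```

In the full loss, this preference term is weighted by `dpo.preference_loss_weight` and combined with an optional SFT term weighted by `dpo.sft_loss_weight`.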
+ #### Multi-node +For distributed DPO training across multiple nodes, modify the following script for your use case: + ```sh # Run from the root of NeMo-Reinforcer repo +## number of nodes to use for your job NUM_ACTOR_NODES=2 # Add a timestamp to make each job name unique TIMESTAMP=$(date +%Y%m%d_%H%M%S) -# grpo_math_8b uses Llama-3.1-8B-Instruct model -COMMAND="uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \ +COMMAND="uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 dpo.val_global_batch_size=32 checkpointing.checkpoint_dir='results/dpo_llama81_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama1b'" \ +RAY_DEDUP_LOGS=0 \ CONTAINER=YOUR_CONTAINER \ MOUNTS="$PWD:$PWD" \ sbatch \ diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md new file mode 100644 index 0000000000..17e3bd303f --- /dev/null +++ b/docs/guides/dpo.md @@ -0,0 +1,169 @@ +# Direct Preference Optimization in Reinforcer + +[Direct Preference Optimization (DPO)](https://arxiv.org/pdf/2305.18290) is an RL-free alignment algorithm that operates on preference data. Given a prompt and a pair of chosen and rejected responses, DPO aims +to increase the probability of the chosen response and decrease the probability of the rejected response relative to a frozen reference model. The actor is initialized using the reference model. For more details, refer to the +[DPO paper](https://arxiv.org/pdf/2305.18290). + +## Launch a DPO Run + +The script [examples/run_dpo.py](../../examples/run_dpo.py) can be used to launch a DPO experiment. This script can either be launched locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md). + +Be sure to launch the job using `uv`. 
The command to launch a DPO job is as follows:
+```bash
+uv run examples/run_dpo.py --config <path-to-config>
+```
+If not specified, `config` will default to [examples/configs/dpo.yaml](../../examples/configs/dpo.yaml).
+
+## Configuration
+
+Reinforcer allows users to configure DPO experiments using `yaml` config files. An example DPO configuration file can be found [here](../../examples/configs/dpo.yaml).
+
+To override a value in the config, either update the value in the `yaml` file directly, or pass the override via the command line. For example:
+
+```bash
+uv run examples/run_dpo.py \
+    cluster.gpus_per_node=8 \
+    dpo.sft_loss_weight=0.1 \
+    dpo.preference_average_log_probs=True \
+    logger.wandb.name="dpo-dev-8-gpu"
+```
+
+**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models.
+
+## Datasets
+
+Each class representing a Reinforcer DPO dataset is expected to have the following attributes:
+1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below.
+2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.
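The smallest possible illustration of a class carrying these two attributes is below. This is a self-contained sketch: `TinyDPODataset` and `FakeTaskSpec` are made-up stand-ins for a real dataset class and `TaskDataSpec`, and plain lists stand in for Hugging Face dataset splits.

```python
from dataclasses import dataclass

@dataclass
class FakeTaskSpec:
    """Made-up stand-in for nemo_reinforcer.data.interfaces.TaskDataSpec."""
    task_name: str

class TinyDPODataset:
    """Minimal object exposing the two attributes a Reinforcer DPO dataset needs."""
    def __init__(self):
        record = {
            "prompt": "What is 2+2?",
            "chosen_response": "4",
            "rejected_response": "5",
        }
        # 1. formatted_ds: must provide `train` and `validation` splits
        #    (plain lists stand in for Hugging Face dataset splits here)
        self.formatted_ds = {"train": [record], "validation": [record]}
        # 2. task_spec: names the dataset
        self.task_spec = FakeTaskSpec(task_name="tiny_dpo")

ds = TinyDPODataset()
assert set(ds.formatted_ds) == {"train", "validation"}
assert ds.formatted_ds["train"][0]["chosen_response"] == "4"
```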
+
+DPO datasets are expected to follow a specific format with three key fields:
+- `prompt`: The input prompt/context
+- `chosen_response`: The preferred/winning response
+- `rejected_response`: The non-preferred/losing response
+
+[data/hf_datasets/helpsteer3.py](../../nemo_reinforcer/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO:
+
+```python
+def format_helpsteer3(data):
+    response_1 = data["response1"]
+    response_2 = data["response2"]
+    overall_preference = data["overall_preference"]
+
+    if overall_preference < 0:
+        chosen = response_1
+        rejected = response_2
+    elif overall_preference == 0:
+        chosen = response_1
+        rejected = response_1
+    else:
+        chosen = response_2
+        rejected = response_1
+
+    return {
+        "prompt": data["context"],
+        "chosen_response": chosen,
+        "rejected_response": rejected,
+    }
+```
+
+We also provide a [DPODataset](../../nemo_reinforcer/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datasets. This class assumes train and validation datasets have been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys.
+
+## Adding Custom DPO Datasets
+
+Adding a new DPO dataset is straightforward. Your custom dataset class should:
+1. Implement the required format conversion in the constructor
+2.
Set up the appropriate `task_spec` + +Here's a minimal example which simply re-keys an existing jsonl dataset: + +```{testcode} +from datasets import load_dataset +from nemo_reinforcer.data.interfaces import TaskDataSpec +from docs.helpers import make_dpo_dataset + +class CustomDPODataset: + def preprocess_dataset( + self, + data, + prompt_key: str = "context", + chosen_key: str = "chosen", + rejected_key: str = "rejected" + ): + return { + "prompt": data[prompt_key], + "chosen_response": data[chosen_key], + "rejected_response": data[rejected_key], + } + + def __init__( + self, + train_data_path: str, + val_data_path: str, + prompt_key: str, + chosen_key: str, + rejected_key: str, + ): + # Load and format your dataset + fn_kwargs={ + "prompt_key": prompt_key, + "chosen_key": chosen_key, + "rejected_key": rejected_key + } + formatted_ds = { + "train": load_dataset("json", data_files=train_data_path, split="train").map( + self.preprocess_dataset, + fn_kwargs=fn_kwargs, + ), + "validation": load_dataset("json", data_files=val_data_path, split="train").map( + self.preprocess_dataset, + fn_kwargs=fn_kwargs, + ), + } + + # Initialize task spec with dataset name + self.task_spec = TaskDataSpec( + task_name="custom_dpo", + ) + self.formatted_ds = formatted_ds + +# Create temporary files using helper function +train_file, val_file = make_dpo_dataset() + +# Initialize dataset +dataset = CustomDPODataset( + train_data_path=train_file.name, + val_data_path=val_file.name, + prompt_key="context", + chosen_key="chosen", + rejected_key="rejected" +) + +# Test dataset properties +print(f"Task name: {dataset.task_spec.task_name}") +print(f"Train examples: {len(dataset.formatted_ds['train'])}") +print(f"Validation examples: {len(dataset.formatted_ds['validation'])}") +print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}") +print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}") +print(f"First train example 
rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}") +``` + +```{testoutput} +Task name: custom_dpo +Train examples: 2 +Validation examples: 2 +First train example prompt: What is 2+2? +First train example chosen response: 4 +First train example rejected response: 5 +``` + +## DPO-Specific Parameters + +The DPO implementation in Reinforcer supports several key parameters that can be adjusted: + +- `dpo.reference_policy_kl_penalty`: Controls the strength of the KL penalty term +- `dpo.preference_loss_weight`: Weight for the preference loss +- `dpo.sft_loss_weight`: Weight for the auxiliary SFT loss +- `dpo.preference_average_log_probs`: Whether to average log probabilities over tokens in the preference loss term +- `dpo.sft_average_log_probs`: Whether to average log probabilities over tokens in the SFT loss term + +These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case. diff --git a/docs/helpers.py b/docs/helpers.py new file mode 100755 index 0000000000..805d5877d1 --- /dev/null +++ b/docs/helpers.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import tempfile
+import json
+
+
+def make_dpo_dataset():
+    train_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
+    val_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
+
+    # Write train data
+    train_data = [
+        {"context": "What is 2+2?", "chosen": "4", "rejected": "5"},
+        {"context": "What is 3*3?", "chosen": "9", "rejected": "6"},
+    ]
+    for item in train_data:
+        train_file.write(json.dumps(item) + "\n")
+    train_file.flush()
+
+    # Write validation data
+    val_data = [
+        {"context": "What is 4+4?", "chosen": "8", "rejected": "7"},
+        {"context": "What is 5*5?", "chosen": "25", "rejected": "20"},
+    ]
+    for item in val_data:
+        val_file.write(json.dumps(item) + "\n")
+    val_file.flush()
+
+    return train_file, val_file
diff --git a/docs/index.md b/docs/index.md
index 0b802b0ce2..5daf42c2d7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -17,6 +17,7 @@
cluster.md
adding-new-models.md
guides/sft.md
+guides/dpo.md
guides/grpo.md
guides/eval.md
```
diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
new file mode 100755
index 0000000000..f4b4b41c27
--- /dev/null
+++ b/examples/configs/dpo.yaml
@@ -0,0 +1,105 @@
+# DPO Algorithm Configuration
+dpo:
+  max_num_epochs: 1
+  max_num_steps: 150
+  val_period: 25
+  val_batches: 8
+  val_global_batch_size: 8
+  val_micro_batch_size: 1
+  val_at_start: true
+  seed: 42
+
+  reference_policy_kl_penalty: 0.05
+  preference_average_log_probs: False # whether to normalize log probs by sequence length in the preference loss
+  sft_average_log_probs: ${.preference_average_log_probs} # whether to normalize log probs by sequence length in the SFT loss
+
+  ## TODO(@ashors) support other loss functions
+  #preference_loss: dpo # the preference loss; we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl
+  #gt_reward_scale: 1.
# the scale of the rewards in RPO + preference_loss_weight: 1 # the coefficient of the preference loss + sft_loss_weight: 0 # the coefficient of the SFT loss + +checkpointing: + enabled: true + checkpoint_dir: "results/dpo" + metric_name: "val_loss" + higher_is_better: false + keep_top_k: 3 + save_period: 50 + +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer: + name: "meta-llama/Llama-3.2-1B-Instruct" + + # number of preference samples per batch + # each preference sample corresponds to a pair of chosen and rejected responses + # so the actual batch size processed by the model is train_global_batch_size * 2 + train_global_batch_size: 128 + train_micro_batch_size: 2 + + ## TODO(@ashors) support + #logprob_batch_size: ${policy.train_micro_batch_size} + max_total_sequence_length: 1024 + precision: "float32" + fsdp_offload_enabled: false + activation_checkpointing_enabled: false + + dtensor_cfg: + enabled: false + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + ## NOTE: there is a known issue with gradient clipping when using Dtensor + ## if using dtensor, set max_grad_norm to NULL + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 20 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [20] + +data: + dataset_name: "HelpSteer3" + max_input_seq_length: ${policy.max_total_sequence_length} +logger: + log_dir: "logs" 
# Base directory for all logs + wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "dpo-dev" + name: "dpo" + tensorboard: + log_dir: "tb_logs-dpo-dev" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/examples/run_dpo.py b/examples/run_dpo.py new file mode 100644 index 0000000000..f780933310 --- /dev/null +++ b/examples/run_dpo.py @@ -0,0 +1,269 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import pprint +import warnings +from typing import Dict, Any + +from omegaconf import OmegaConf + +from nemo_reinforcer.algorithms.dpo import MasterConfig, dpo_train, setup +from nemo_reinforcer.algorithms.utils import get_tokenizer +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides +from nemo_reinforcer.utils.logger import get_next_experiment_dir +from nemo_reinforcer.data import DataConfig, hf_datasets +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset +from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec +from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log +from transformers import AutoTokenizer +from nemo_reinforcer.models.policy import PolicyConfig + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run DPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# ======================================================= +# Data Processing +# ======================================================= +def dpo_preprocessor( + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary for DPO training. 
+ + Examples: + ```{doctest} + >>> from transformers import AutoTokenizer + >>> from nemo_reinforcer.data.interfaces import TaskDataSpec + >>> + >>> # Initialize tokenizer and task spec + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + >>> ## set a passthrough chat template for simplicity + >>> tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}" + >>> task_spec = TaskDataSpec(task_name="test_dpo") + >>> + >>> datum = { + ... "prompt": "What is 2+2?", + ... "chosen_response": "4", + ... "rejected_response": "5" + ... } + >>> + >>> processed = dpo_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) + >>> len(processed["message_log_chosen"]) + 2 + >>> processed["message_log_chosen"][0]["content"] + '<|begin_of_text|>What is 2+2?' + >>> processed["message_log_chosen"][-1]["content"] + '4<|eot_id|>' + >>> processed["message_log_rejected"][-1]["content"] + '5<|eot_id|>' + >>> + >>> # prompt can also be a list with multiple messages + >>> datum = { + ... "prompt": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], + ... "chosen_response": "4", + ... "rejected_response": "5" + ... } + >>> processed = dpo_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) + >>> len(processed["message_log_chosen"]) + 4 + >>> processed["message_log_chosen"][1]["content"] + 'Sure!' 
+ >>> processed["message_log_chosen"][-1]["content"] + '4<|eot_id|>' + >>> processed["message_log_rejected"][-1]["content"] + '5<|eot_id|>' + + ``` + """ + if isinstance(datum_dict["prompt"], list): + messages_chosen = datum_dict["prompt"].copy() + messages_rejected = datum_dict["prompt"].copy() + else: + messages_chosen = [ + { + "role": "user", + "content": datum_dict["prompt"], + }, + ] + messages_rejected = [ + { + "role": "user", + "content": datum_dict["prompt"], + }, + ] + + messages_chosen.append( + { + "role": "assistant", + "content": datum_dict["chosen_response"], + }, + ) + + messages_rejected.append( + { + "role": "assistant", + "content": datum_dict["rejected_response"], + }, + ) + + message_log_chosen = get_formatted_message_log( + messages_chosen, tokenizer, task_data_spec + ) + message_log_rejected = get_formatted_message_log( + messages_rejected, tokenizer, task_data_spec + ) + + length_chosen = sum(len(m["token_ids"]) for m in message_log_chosen) + length_rejected = sum(len(m["token_ids"]) for m in message_log_rejected) + + loss_multiplier = 1.0 + if max(length_chosen, length_rejected) > max_seq_length: + warnings.warn( + f"Sequence length {max(length_chosen, length_rejected)} exceeds max_seq_length {max_seq_length}. Ignoring example." 
+ ) + # make smaller and mask out + for message in message_log_chosen: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log_chosen)) + ] + for message in message_log_rejected: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log_rejected)) + ] + loss_multiplier = 0.0 + + output = { + "message_log_chosen": message_log_chosen, + "length_chosen": length_chosen, + "message_log_rejected": message_log_rejected, + "length_rejected": length_rejected, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + return output + + +def setup_data(data_config: DataConfig, policy_config: PolicyConfig): + print("\n▶ Setting up data...") + + if data_config["dataset_name"] == "HelpSteer3": + data = hf_datasets.HelpSteer3Dataset() + else: + data = hf_datasets.DPODataset( + train_data_path=data_config["train_data_path"], + val_data_path=data_config["val_data_path"], + ) + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] + + dpo_task_spec = data.task_spec + + tokenizer = get_tokenizer(policy_config["tokenizer"]) + train_dataset = AllTaskProcessedDataset( + train_dataset, + tokenizer, + dpo_task_spec, + dpo_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset = AllTaskProcessedDataset( + val_dataset, + tokenizer, + dpo_task_spec, + dpo_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + return train_dataset, val_dataset, tokenizer, dpo_task_spec + + +def main(): + """Main entry point.""" + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join(os.path.dirname(__file__), "configs", "dpo.yaml") + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + 
print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"📊 Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # setup data + train_dataset, val_dataset, tokenizer, dpo_task_spec = setup_data( + config["data"], config["policy"] + ) + ( + policy, + cluster, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + dpo_save_state, + master_config, + ) = setup(config, tokenizer, train_dataset, val_dataset) + dpo_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + dpo_save_state, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py new file mode 100644 index 0000000000..64d14164d5 --- /dev/null +++ b/nemo_reinforcer/algorithms/dpo.py @@ -0,0 +1,520 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import warnings +from collections import defaultdict +from functools import partial +from pathlib import Path +from transformers import AutoTokenizer +from typing import Optional, Tuple, TypedDict +from tqdm import tqdm + +import numpy as np +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from nemo_reinforcer.algorithms.loss_functions import ( + DPOLossFn, +) +from nemo_reinforcer.algorithms.utils import set_seed, get_tokenizer +from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, dpo_collate_fn +from nemo_reinforcer.data.interfaces import TaskDataSpec +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_reinforcer.models.interfaces import PolicyInterface +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig +from nemo_reinforcer.utils.logger import Logger, LoggerConfig +from nemo_reinforcer.utils.timer import Timer + + +class DPOSaveState(TypedDict): + epoch: int # Track current epoch + step: int # Track step within current epoch + total_steps: int # Track total number of steps across all epochs + val_loss: float + consumed_samples: int + + +def _default_dpo_save_state() -> DPOSaveState: + return { + "epoch": 0, + "step": 0, + "total_steps": 0, + "consumed_samples": 0, + } + + +class DPOConfig(TypedDict): + max_num_epochs: int + max_num_steps: int + val_period: int + val_batches: int + val_global_batch_size: int + val_micro_batch_size: int + val_at_start: bool + seed: int + + reference_policy_kl_penalty: float + preference_average_log_probs: bool + sft_average_log_probs: bool + ## TODO(@ashors) support other loss functions + ## https://github.com/NVIDIA/reinforcer/issues/193 + # preference_loss: str 
+ # gt_reward_scale: float + preference_loss_weight: float + sft_loss_weight: float + + +class MasterConfig(TypedDict): + policy: PolicyConfig + data: DataConfig + dpo: DPOConfig + logger: LoggerConfig + cluster: ClusterConfig + checkpointing: CheckpointingConfig + + +# ======================================================= +# Setup & Initialization +# ======================================================= +def setup( + master_config: MasterConfig, + tokenizer: AutoTokenizer, + train_dataset: AllTaskProcessedDataset, + val_dataset: AllTaskProcessedDataset, +) -> Tuple[ + HfPolicy, + RayVirtualCluster, + StatefulDataLoader, + StatefulDataLoader, + DPOLossFn, + MasterConfig, + Logger, + TaskDataSpec, + DPOSaveState, +]: + """Main entry point for running DPO algorithm. + + Returns: + Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + """ + set_seed(master_config["dpo"]["seed"]) + + # Extract individual configs for easier access + policy_config = master_config["policy"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + dpo_config = master_config["dpo"] + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + dpo_save_state: Optional[DPOSaveState] = checkpointer.load_training_info( + last_checkpoint_path + ) + # config validation checks + if master_config["checkpointing"]["enabled"]: + assert master_config["checkpointing"]["save_period"] > 0 + assert ( + master_config["checkpointing"]["save_period"] + % master_config["dpo"]["val_period"] + == 0 + ), ( + f"Checkpointing save period {master_config['checkpointing']['save_period']} " + f"must be a 
multiple of validation period {master_config['dpo']['val_period']}" + f", or we won't know what metric to save!" + ) + + # ========================== + # Data + # ========================== + ## TODO(@ashors) reduce boilerplate and move reused code into utils + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=policy_config["train_global_batch_size"], + shuffle=True, + collate_fn=partial( + dpo_collate_fn, + tokenizer=tokenizer, + make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + ), + drop_last=True, + ) + + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + train_dataloader.load_state_dict(dataloader_state_dict) + + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=dpo_config["val_global_batch_size"], + shuffle=False, + collate_fn=partial( + dpo_collate_fn, + tokenizer=tokenizer, + make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + ), + drop_last=True, + ) + + # ========================== + # Cluster + # ========================== + print("\n▶ Setting up compute cluster...") + cluster = RayVirtualCluster( + name="dpo_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1, + ) + print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + + # ========================== + # Training + # ========================== + print("\n▶ Setting up model...") + policy = HfPolicy( + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + weights_path=Path(last_checkpoint_path) / "policy" / "weights" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" + if last_checkpoint_path + else None, + init_optimizer=True, + init_reference_model=True, + 
) + loss_fn = DPOLossFn(master_config["dpo"]) + print(f" ✓ Model initialized") + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n") + + return ( + policy, + cluster, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + dpo_save_state, + master_config, + ) + + +def add_ref_logprobs_to_data(dataloader, policy, master_config): + dataloader_iter = iter(dataloader) + while True: + try: + batch = next(dataloader_iter) + + ## append ref policy logprobs to batch + logprobs = policy.get_reference_policy_logprobs( + batch, + micro_batch_size=master_config["policy"]["train_micro_batch_size"] * 2, + )["reference_logprobs"] + ## want logprobs for batch to correspond to the log probabilities of the next tokens + ## so we roll the logprobs to the left by one + batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1) + + yield batch + + except StopIteration: + break + + +# ======================================================= +# Training & Validation +# ======================================================= +def validate( + policy: PolicyInterface, + val_dataloader: StatefulDataLoader, + tokenizer, + loss_fn, + step: int, + master_config: MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, +): + """Run validation on the validation dataset.""" + if val_dataloader is None: + print(" ⚠️ No validation dataloader provided, skipping validation") + return + + timer = Timer() + + with timer.time("total_validation_time"): + print(f"▶ Starting validation at step {step}...") + + val_metrics = defaultdict(lambda: 0.0) + num_valid_batches = 0 + for batch_idx, val_batch in enumerate( + add_ref_logprobs_to_data(val_dataloader, policy, master_config) + ): + ## just run model fwd + val_results = policy.train( + val_batch, + loss_fn, + eval_mode=True, + gbs=val_batch_size * 2, + mbs=val_mbs * 2, + ) + + if len(val_results["all_mb_metrics"]) == 0: + warnings.warn( + "No validation metrics were collected 
for this batch."
+                    " This is likely because there were no valid samples."
+                )
+
+            else:
+                for k, v in val_results["all_mb_metrics"].items():
+                    val_metrics[k] += np.mean(v).item()
+                num_valid_batches += 1
+
+            if val_batches > 0 and batch_idx >= val_batches - 1:
+                break
+
+        for k, v in val_metrics.items():
+            if k == "num_valid_samples":
+                continue
+            val_metrics[k] /= num_valid_batches
+
+        # Restore training mode on the policy now that validation is done
+        policy.prepare_for_training()
+
+        # Get timing metrics
+        timing_metrics = timer.get_timing_metrics(reduction_op="sum")
+        validation_time = timing_metrics.get("total_validation_time", 0)
+
+        if len(val_metrics) == 0:
+            warnings.warn(
+                "No validation metrics were collected."
+                " This is likely because there were no valid samples in the validation set."
+            )
+
+        else:
+            # Print summary of validation results
+            print("\n📊 Validation Results:")
+            print(f"  • Validation loss: {float(val_metrics['loss']):.4f}")
+            print(f"  • Validation accuracy: {float(val_metrics['accuracy']):.4f}")
+
+            # Print timing information
+            print("\n  ⏱️ Validation Timing:")
+            print(f"  • Total validation time: {validation_time:.2f}s")
+
+    # Make sure to reset the timer after validation
+    timer.reset()
+
+    return val_metrics, timing_metrics
+
+
+def dpo_train(
+    policy,
+    train_dataloader,
+    val_dataloader,
+    tokenizer,
+    loss_fn,
+    master_config,
+    logger,
+    checkpointer,
+    dpo_save_state,
+):
+    """Run the DPO training loop."""
+    timer = Timer()
+
+    if dpo_save_state is None:
+        dpo_save_state = _default_dpo_save_state()
+        current_epoch = 0
+        current_step = 0
+        total_steps = 0
+    else:
+        current_epoch = dpo_save_state["epoch"]
+        current_step = dpo_save_state["step"]
+        total_steps = dpo_save_state["total_steps"]
+
+    dpo_config = master_config["dpo"]
+    # Validation configuration
+    val_period = dpo_config["val_period"]
+    val_at_start = dpo_config["val_at_start"]
+    max_num_epochs = dpo_config["max_num_epochs"]
+
+    # Run validation at the 
start if configured + if val_at_start and total_steps == 0: + print("\n🔍 Running initial validation...") + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=0, + master_config=master_config, + val_batches=dpo_config["val_batches"], + val_batch_size=dpo_config["val_global_batch_size"], + val_mbs=dpo_config["val_micro_batch_size"], + ) + + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") + + policy.prepare_for_training() + + while ( + current_epoch < max_num_epochs + and total_steps < master_config["dpo"]["max_num_steps"] + ): + print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") + + for batch in add_ref_logprobs_to_data(train_dataloader, policy, master_config): + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['dpo']['max_num_steps'])} {'=' * 25}" + ) + + with timer.time("total_step_time"): + print("▶ Taking a training step...") + train_results = policy.train( + batch, + loss_fn, + eval_mode=False, + ## NOTE: we double the batch size here because each preference example corresponds to a pair of + ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch. 
+ gbs=master_config["policy"]["train_global_batch_size"] * 2, + mbs=master_config["policy"]["train_micro_batch_size"] * 2, + ) + + # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=total_steps + 1, + master_config=master_config, + val_batches=dpo_config["val_batches"], + val_batch_size=dpo_config["val_global_batch_size"], + val_mbs=dpo_config["val_micro_batch_size"], + ) + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + + ## Checkpointing + dpo_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + if ( + master_config["checkpointing"]["enabled"] + and (total_steps + 1) + % master_config["checkpointing"]["save_period"] + == 0 + ): # +1 because step is 0-indexed + is_last_checkpoint = ( + min( + len(train_dataloader) * max_num_epochs, + master_config["dpo"]["max_num_steps"], + ) + - (total_steps + 1) + < master_config["checkpointing"]["save_period"] + ) + dpo_save_state["step"] = (current_step + 1) % len(train_dataloader) + dpo_save_state["total_steps"] = total_steps + 1 + dpo_save_state["epoch"] = current_epoch + dpo_save_state["val_loss"] = val_metrics["loss"] + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {total_steps + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, dpo_save_state, master_config + ) + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + save_hf=is_last_checkpoint, + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + losses = 
train_results["loss"] + metrics = { + "loss": train_results["loss"].numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print("\n📊 Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + print("\n⏱️ Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + + # Display all other timing metrics (if any) + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + + if total_steps >= master_config["dpo"]["max_num_steps"]: + return + + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 84e02b39a9..1914b27e98 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -660,7 +660,11 @@ def grpo_train( "reward": rewards.numpy(), } metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() metrics.update(gen_metrics) timing_metrics = timer.get_timing_metrics(reduction_op="sum") diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index ef5a698678..0d2e61a9ac 100644 --- 
a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -154,13 +154,20 @@ def __call__( "probs_ratio_clamped": probs_ratio_clamped, "kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0, "token_mult_prob_error": mult_prob_error, + "num_valid_samples": sample_mask.sum().item(), }, ) class NLLLoss(LossFunction): + """Negative Log Likelihood Loss function.""" + def __call__( - self, next_token_logits: torch.Tensor, data: BatchedDataDict + self, + next_token_logits: torch.Tensor, + data: BatchedDataDict, + dpo_loss: bool = False, + dpo_average_log_probs: bool = False, ) -> Tuple[torch.Tensor, dict]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position @@ -185,16 +192,198 @@ def __call__( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) - # Only compute loss on generated tokens (not input tokens) - # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - num_unmasked_tokens = torch.sum(mask) - if num_unmasked_tokens == 0: - # prevent division by zero - num_unmasked_tokens = torch.tensor(1) - loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens + if dpo_loss: + ## shape: [batch_size] + num_unmasked_tokens = torch.sum(mask, -1) + ## multiply by sample_mask to zero out invalid samples + loss = -torch.sum(token_logprobs * mask, dim=-1) + if dpo_average_log_probs: + loss = loss / num_unmasked_tokens.clamp(min=1) + else: + ## single scalar loss + # Only compute loss on generated tokens (not input tokens) + # by applying the token_loss_mask + num_unmasked_tokens = torch.sum(mask) + if num_unmasked_tokens == 0: + # prevent division by zero + num_unmasked_tokens = torch.tensor(1) + loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens + num_unmasked_tokens = num_unmasked_tokens.item() return loss, { - "loss": loss.item(), - "num_unmasked_tokens": num_unmasked_tokens.item(), + "loss": loss.item() if loss.ndim == 0 else 
loss,
+            "num_unmasked_tokens": num_unmasked_tokens,
             "total_tokens": mask.numel(),
+            "num_valid_samples": sample_mask.sum().item(),
+        }
+
+
+class DPOLossConfig(TypedDict):
+    ## NOTE: TypedDict fields cannot carry runtime defaults; the values below
+    ## are typically configured as preference_loss_weight=1.0,
+    ## sft_loss_weight=0.0, preference_average_log_probs=False,
+    ## sft_average_log_probs=False.
+    reference_policy_kl_penalty: float
+    preference_loss_weight: float
+    sft_loss_weight: float
+    preference_average_log_probs: bool
+    sft_average_log_probs: bool
+
+
+class DPOLossDataDict(TypedDict):
+    """Required keys for the DPO loss function."""
+
+    input_ids: torch.Tensor
+    reference_policy_logprobs: torch.Tensor
+    token_mask: torch.Tensor
+    sample_mask: torch.Tensor
+
+
+class DPOLossFn(LossFunction):
+    """Direct Preference Optimization (DPO) loss function.
+
+    This loss function implements the DPO algorithm as described in:
+    "Direct Preference Optimization: Your Language Model is Secretly a Reward Model"
+    (https://arxiv.org/abs/2305.18290)
+
+    The loss combines two main components:
+    1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones
+    2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses
+
+    The total loss is computed as:
+    L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ)
+
+    where:
+    - w_p is the preference_loss_weight
+    - w_s is the sft_loss_weight
+    - L_pref(θ) is the preference loss term
+    - L_sft(θ) is the supervised fine-tuning loss term
+
+    The preference loss term is computed as:
+    L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]
+
+    where:
+    - σ is the sigmoid function
+    - β is the reference_policy_kl_penalty
+    - r_chosen and r_rejected are the rewards for chosen and rejected responses
+    - The rewards are computed as the sum of log probability differences between
+      the current policy and reference policy
+
+    If preference_average_log_probs is True, the rewards are averaged over tokens:
+    r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t))
+
+    Otherwise, the rewards are summed over tokens. 
+ + The SFT loss term is a standard negative log likelihood loss on the chosen responses. + If sft_average_log_probs is True, the loss is averaged over tokens. + + Args: + cfg (DPOLossConfig): Configuration dictionary containing: + - reference_policy_kl_penalty (float): Strength of the KL penalty term (β) + - preference_loss_weight (float): Weight for the preference loss term (w_p) + - sft_loss_weight (float): Weight for the SFT loss term (w_s) + - preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss + - sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss + + Returns: + Tuple[torch.Tensor, dict]: A tuple containing: + - The total loss value + - A dictionary with metrics including: + - loss: Total loss value + - sft_loss: SFT loss component + - preference_loss: Preference loss component + - accuracy: Fraction of examples where chosen response has higher reward + """ + + def __init__(self, cfg: DPOLossConfig): + self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] + self.preference_loss_weight = cfg["preference_loss_weight"] + self.sft_loss_weight = cfg["sft_loss_weight"] + self.preference_average_log_probs = cfg["preference_average_log_probs"] + self.sft_average_log_probs = cfg["sft_average_log_probs"] + self.sft_loss = NLLLoss() + + def split_output_tensor(self, tensor: torch.Tensor): + return tensor[::2], tensor[1::2] + + def preference_loss( + self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] + ) -> torch.Tensor: + ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. 
We should refactor + token_mask = data["token_mask"][:, 1:] + sample_mask = data["sample_mask"] + + next_token_logits = next_token_logits.to(torch.float32) + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"] + ) + else: + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + ref_logprobs = data["reference_policy_logprobs"][:, :-1] + + diff = (token_logprobs - ref_logprobs) * token_mask + + rewards = diff.sum(-1) + if self.preference_average_log_probs: + rewards = rewards / token_mask.sum(-1).clamp(min=1) + + rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) + rewards_delta = rewards_chosen - rewards_rejected + + per_sample_loss = ( + -torch.nn.functional.logsigmoid( + self.reference_policy_kl_penalty * rewards_delta + ) + * sample_mask[::2] + ) ## zero out invalid samples + + return ( + masked_mean(per_sample_loss, sample_mask[::2]), + (rewards_chosen > rewards_rejected).float().mean(0), + masked_mean(rewards_chosen, sample_mask[::2]), + masked_mean(rewards_rejected, sample_mask[1::2]), + ) + + def __call__( + self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] + ) -> Tuple[torch.Tensor, dict]: + sft_loss_chosen = torch.tensor(0.0) + if self.sft_loss_weight > 0: + sft_loss, _ = self.sft_loss( + next_token_logits, + data, + dpo_loss=True, + dpo_average_log_probs=self.sft_average_log_probs, + ) + sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) + sft_loss_chosen = masked_mean(sft_loss_chosen, data["sample_mask"][::2]) + + ( + preference_loss, + accuracy, + rewards_chosen_mean, + rewards_rejected_mean, + ) = 
self.preference_loss(next_token_logits, data) + + dpo_loss = ( + self.sft_loss_weight * sft_loss_chosen + + self.preference_loss_weight * preference_loss + ) + + ## divide by 2 because we're summing over (chosen, rejected) pairs + num_valid_samples = data["sample_mask"].sum() / 2 + + return dpo_loss, { + "loss": dpo_loss.item(), + "sft_loss": sft_loss_chosen.item(), + "preference_loss": preference_loss.item(), + "accuracy": accuracy.item(), + "rewards_chosen_mean": rewards_chosen_mean.item(), + "rewards_rejected_mean": rewards_rejected_mean.item(), + "num_valid_samples": num_valid_samples.item(), } diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index e6a6b3f418..f0dc551026 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from transformers import AutoTokenizer from pathlib import Path from typing import Optional, Tuple, TypedDict @@ -236,6 +237,7 @@ def validate( # val_total = len(val_dataloader) val_metrics = {"val_loss": 0.0} + num_valid_batches = 0 policy.prepare_for_training() for batch_idx, val_batch in enumerate(val_dataloader): @@ -270,12 +272,26 @@ def validate( gbs=val_batch_size, mbs=val_mbs, ) - val_metrics["val_loss"] += float(val_results["loss"]) - if val_batches > 0 and batch_idx >= val_batches: + if len(val_results["all_mb_metrics"]) == 0: + warnings.warn( + "No validation metrics were collected for this batch." + " This is likely because there were no valid samples." + ) + else: + val_metrics["val_loss"] += float(val_results["loss"]) + num_valid_batches += 1 + + if val_batches > 0 and batch_idx >= val_batches - 1: break - val_metrics["val_loss"] /= val_batches + if num_valid_batches > 0: + val_metrics["val_loss"] /= num_valid_batches + else: + warnings.warn( + "No validation metrics were collected." 
+ " This is likely because there were no valid samples in the validation set." + ) # Calculate validation metrics policy.prepare_for_training() @@ -284,14 +300,15 @@ def validate( timing_metrics = timer.get_timing_metrics(reduction_op="sum") validation_time = timing_metrics.get("total_validation_time", 0) - # Print summary of validation results - print("\n📊 Validation Results:") - print(f" • Validation loss: {val_metrics['val_loss']:.4f}") + if num_valid_batches > 0: + # Print summary of validation results + print("\n📊 Validation Results:") + print(f" • Validation loss: {val_metrics['val_loss']:.4f}") - # Print timing information - print("\n ⏱️ Validation Timing:") - validation_time = timing_metrics.get("total_validation_time", 0) - print(f" • Total validation time: {validation_time:.2f}s") + # Print timing information + print("\n ⏱️ Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" • Total validation time: {validation_time:.2f}s") # Make sure to reset the timer after validation timer.reset() @@ -442,7 +459,11 @@ def sft_train( "loss": train_results["loss"].numpy(), } metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\n📊 Training Results:") diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 8a81c85fb2..8d8ca78371 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -21,6 +21,10 @@ TaskDataProcessFnCallable, DatumSpec, ) +from nemo_reinforcer.data.llm_message_utils import ( + add_loss_mask_to_message_log, + batched_message_log_to_flat_message, +) from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict @@ -181,3 +185,64 @@ def eval_collate_fn(data_batch: List[DatumSpec]) 
-> BatchedDataDict: idx=idx, ) return output + + +def dpo_collate_fn( + data_batch: List[DatumSpec], tokenizer, make_sequence_length_divisible_by: int +) -> BatchedDataDict: + """Collate function for DPO training. + + This function separates the chosen and rejected responses to create + two examples per prompt. The chosen and rejected examples are interleaved + along the batch dimension, resulting in a batch size of 2 * len(data_batch). + """ + message_log = [] + length = [] + loss_multiplier = [] + idx = [] + task_names = [] + for datum_spec in data_batch: + ## interleave chosen and rejected examples + message_log.append(datum_spec["message_log_chosen"]) + message_log.append(datum_spec["message_log_rejected"]) + length.append(datum_spec["length_chosen"]) + length.append(datum_spec["length_rejected"]) + loss_multiplier.extend([datum_spec["loss_multiplier"]] * 2) + idx.extend([datum_spec["idx"]] * 2) + task_names.extend([datum_spec.get("task_name", None)] * 2) + length = torch.tensor(length) + loss_multiplier = torch.tensor(loss_multiplier) + + batch_max_length = torch.ones_like(length) * length.max() + + batch = BatchedDataDict( + message_log=message_log, + length=length, + loss_multiplier=loss_multiplier, + task_name=task_names, + idx=idx, + batch_max_length=batch_max_length, + ) + + ## add loss mask based on role to every message + add_loss_mask_to_message_log( + batch["message_log"], + only_unmask_final=True, + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=make_sequence_length_divisible_by, + ) + + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": loss_multiplier, + } + ) + + return train_data diff --git a/nemo_reinforcer/data/hf_datasets/__init__.py 
b/nemo_reinforcer/data/hf_datasets/__init__.py index 919f1a494e..c6e7c8c75c 100644 --- a/nemo_reinforcer/data/hf_datasets/__init__.py +++ b/nemo_reinforcer/data/hf_datasets/__init__.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES +from nemo_reinforcer.data.hf_datasets.dpo import DPODataset +from nemo_reinforcer.data.hf_datasets.helpsteer3 import HelpSteer3Dataset +from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset from nemo_reinforcer.data.hf_datasets.prompt_response_dataset import ( PromptResponseDataset, ) -from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset from nemo_reinforcer.data.hf_datasets.squad import SquadDataset -from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES __all__ = [ + "DPODataset", + "HelpSteer3Dataset", "OasstDataset", "PromptResponseDataset", "SquadDataset", diff --git a/nemo_reinforcer/data/hf_datasets/dpo.py b/nemo_reinforcer/data/hf_datasets/dpo.py new file mode 100644 index 0000000000..f0cb022498 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/dpo.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
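As an editorial aside before the dataset files: per (chosen, rejected) pair, the preference term implemented by `DPOLossFn.preference_loss` above reduces to `-log σ(β · Δr)`. A pure-Python sketch of that scalar formula (illustrative only; the function name here is made up and is not part of this patch):

```python
import math


def dpo_pair_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.05):
    """-log(sigmoid(beta * delta)) for one (chosen, rejected) pair.

    Each argument is a summed sequence log-probability, matching the
    default (non-averaged) reward described in DPOLossFn's docstring.
    """
    reward_chosen = policy_chosen - ref_chosen
    reward_rejected = policy_rejected - ref_rejected
    delta = reward_chosen - reward_rejected
    # -log(sigmoid(x)) == log(1 + exp(-x)); log1p keeps it numerically stable
    return math.log1p(math.exp(-beta * delta))


# At delta == 0 (policy agrees with the reference) the loss is exactly log(2);
# it shrinks as the policy separates chosen from rejected responses.
neutral = dpo_pair_loss(-11.0, -11.0, -11.0, -11.0)
better = dpo_pair_loss(-10.0, -12.0, -11.0, -11.0)
```

With `preference_average_log_probs=True`, the same formula applies after dividing each summed log-probability by its unmasked token count.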
+from datasets import load_dataset + +from nemo_reinforcer.data.interfaces import TaskDataSpec + + +class DPODataset: + """Dataset class for Direct Preference Optimization (DPO) training. + + This class handles loading of preference data for DPO training. + The input JSON files should contain examples with the following structure: + { + "prompt": str, # The input prompt/context + "chosen_response": str, # The preferred/winning response + "rejected_response": str # The non-preferred/losing response + } + + Args: + train_data_path (str): Path to the JSON file containing training data + val_data_path (str): Path to the JSON file containing validation data + + """ + + def __init__(self, train_data_path: str, val_data_path: str): + self.formatted_ds = { + "train": load_dataset("json", data_files=train_data_path, split="train"), + "validation": load_dataset("json", data_files=val_data_path, split="train"), + } + + self.task_spec = TaskDataSpec( + task_name="DPO", + ) diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py new file mode 100644 index 0000000000..0ad0263c30 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
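The `DPODataset` added above reads its JSON files with `load_dataset("json", ...)`, one record per line. As a quick illustration of the record shape its docstring documents, using only the standard library (the file name and example strings are invented for the sketch):

```python
import json
import os
import tempfile

# One preference example per line, in the shape DPODataset's docstring describes.
records = [
    {
        "prompt": "What is 2 + 2?",
        "chosen_response": "2 + 2 = 4.",
        "rejected_response": "2 + 2 = 5.",
    },
]

path = os.path.join(tempfile.mkdtemp(), "dpo_train.jsonl")
with open(path, "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# load_dataset("json", data_files=path, split="train") would parse the same lines.
with open(path) as f:
    loaded = [json.loads(line) for line in f]
```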
+from datasets import load_dataset +from absl import logging + +from nemo_reinforcer.data.interfaces import TaskDataSpec + + +def format_helpsteer3(data): + response_1 = data["response1"] + response_2 = data["response2"] + overall_preference = data["overall_preference"] + + if overall_preference < 0: + chosen = response_1 + rejected = response_2 + elif overall_preference == 0: + logging.log_every_n( + logging.WARNING, + "Preference is 0 for some examples! Setting chosen and rejected to response 1 since we don't know which response is better", + 1000, + ) + chosen = response_1 + rejected = response_1 + else: + chosen = response_2 + rejected = response_1 + + return { + "prompt": data["context"], + "chosen_response": chosen, + "rejected_response": rejected, + } + + +class HelpSteer3Dataset: + """HelpSteer3 preference dataset for DPO training.""" + + def __init__(self): + ds = load_dataset("nvidia/HelpSteer3", "preference") + self.formatted_ds = ds.map(format_helpsteer3) + + self.task_spec = TaskDataSpec( + task_name="HelpSteer3", + ) diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index e4fb959f34..362183d978 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -109,22 +109,34 @@ def get_keys_from_message_log( def add_loss_mask_to_message_log( message_log: LLMMessageLogType, roles_to_train_on: List[str] = ["assistant"], + only_unmask_final: bool = False, ) -> None: """Add token-level loss masks to each message in a message log. Args: message_log (LLMMessageLogType): List of message dictionaries containing token IDs and metadata roles_to_train_on (List[str]): List of strings indicating which speakers to unmask. Default: ["assistant"] + only_unmask_final (bool): If True, only unmask the final message in the log. 
Default: False """ for i, role in enumerate(roles_to_train_on): roles_to_train_on[i] = role.lower() for message in message_log: - for sentence in message: - if sentence["role"] in roles_to_train_on: - sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + for i, sentence in enumerate(message): + if only_unmask_final: + if i == len(message) - 1: + sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + else: + sentence["token_loss_mask"] = torch.zeros_like( + sentence["token_ids"] + ) else: - sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) + if sentence["role"] in roles_to_train_on: + sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + else: + sentence["token_loss_mask"] = torch.zeros_like( + sentence["token_ids"] + ) def _pad_tensor( diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index a7c7f717fb..2c4bd78efd 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -15,7 +15,7 @@ import gc from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, Optional import ray @@ -255,119 +255,134 @@ def train( local_gbs = gbs // self.dp_size dataset_size = data.get("input_ids").shape[0] - # Ensure model is in training mode - self.model.train() - - # Get data from batch and move to device - data.to("cuda") - - losses = [] - all_mb_metrics = [] - for gb_start in range(0, dataset_size, local_gbs): - self.optimizer.zero_grad() - mb_losses = [] + if eval_mode: + ctx = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + # Ensure model is in training mode + self.model.train() - # Calculate number of microbatches to process - # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size - # so its safe to not check for 
the case where the last data slice is smaller than mbs - num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs + with ctx: + # Get data from batch and move to device + data.to("cuda") - for mb in data.slice( - gb_start, gb_start + local_gbs - ).make_microbatch_iterator(mbs): - input_ids = mb.get("input_ids").cuda() + losses = [] + all_mb_metrics = [] + for gb_start in range(0, dataset_size, local_gbs): + self.optimizer.zero_grad() + mb_losses = [] - input_lengths = mb.get("input_lengths") - batch_size, seq_len = input_ids.shape + # Calculate number of microbatches to process + # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size + # so its safe to not check for the case where the last data slice is smaller than mbs + num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs - attention_mask = torch.zeros( - (batch_size, seq_len), dtype=torch.long, device=input_ids.device - ) - for i, length in enumerate(input_lengths): - # For right-padded sequence, set 1s at the beginning of the sequence - attention_mask[i, :length] = 1 + for mb in data.slice( + gb_start, gb_start + local_gbs + ).make_microbatch_iterator(mbs): + input_ids = mb.get("input_ids").cuda() - with torch.autocast(device_type="cuda", dtype=self.dtype): + input_lengths = mb.get("input_lengths") batch_size, seq_len = input_ids.shape - attention_mask_input_all_ones = torch.ones( + attention_mask = torch.zeros( (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) - position_ids = torch.arange( - seq_len, device=input_ids.device - ).repeat(batch_size, 1) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + attention_mask[i, :length] = 1 - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask_input_all_ones, - position_ids=position_ids, - use_cache=False, - ) + with torch.autocast(device_type="cuda", dtype=self.dtype): + batch_size, seq_len = 
input_ids.shape - # Get logprobs - if not hasattr(outputs, "logits"): - logits = self.model.lm_head(outputs.last_hidden_state) - else: - logits = outputs.logits - - loss, loss_metrics = loss_fn(logits, mb) - loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] - # Backward pass + attention_mask_input_all_ones = torch.ones( + (batch_size, seq_len), + dtype=torch.long, + device=input_ids.device, + ) + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask_input_all_ones, + position_ids=position_ids, + use_cache=False, + ) - # Loss is accumulated across microbatches so we need to scale by the number of microbatches - loss = loss / num_microbatches + # Get logprobs + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + + loss, loss_metrics = loss_fn(logits, mb) + num_valid_samples = loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + # Backward pass + + # Loss is accumulated across microbatches so we need to scale by the number of microbatches + loss = loss / num_microbatches + if not eval_mode: + ## NOTE: invalid samples should be multiplied + ## by zero in the loss function to prevent them + ## from affecting the gradien + loss.backward() + if num_valid_samples > 0: + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + grad_norm = None if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) - - grad_norm = None - if not eval_mode: - with torch.no_grad(): - grad_norm = get_grad_norm( - self.model.parameters(), - dp_group=self.dp_mesh.get_group(), - tp_group=self.tp_mesh.get_group(), - dtype=torch.float32, - ) - if self.max_grad_norm is not None: - clip_grad_by_total_norm_( + with torch.no_grad(): + grad_norm = get_grad_norm( self.model.parameters(), - 
max_grad_norm=self.max_grad_norm, - total_norm=grad_norm, + dp_group=self.dp_mesh.get_group(), + tp_group=self.tp_mesh.get_group(), dtype=torch.float32, ) + if self.max_grad_norm is not None: + clip_grad_by_total_norm_( + self.model.parameters(), + max_grad_norm=self.max_grad_norm, + total_norm=grad_norm, + dtype=torch.float32, + ) - # Update parameters - self.optimizer.step() - self.scheduler.step() - - losses.append(torch.tensor(mb_losses).sum().item()) - - # Compute global loss across all ranks - with torch.no_grad(): - local_loss = torch.tensor(losses, device="cuda") - global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group()) - global_loss = local_loss / self.dp_size - - # Aggregate metrics across all microbatches - mb_metrics = defaultdict(list) - for m in all_mb_metrics: - for k, v in m.items(): - mb_metrics[k].append(v) - - metrics = { - "global_loss": global_loss.cpu(), - "local_loss": local_loss.cpu(), - "grad_norm": grad_norm, - "rank": torch.distributed.get_rank(), - "all_mb_metrics": dict(mb_metrics), - } - - return metrics - - def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + # Update parameters + self.optimizer.step() + self.scheduler.step() + + losses.append(torch.tensor(mb_losses).sum().item()) + + # Compute global loss across all ranks + with torch.no_grad(): + local_loss = torch.tensor(losses, device="cuda") + global_loss = torch.zeros_like(local_loss) + torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group()) + global_loss = local_loss / self.dp_size + + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "local_loss": local_loss.cpu(), + "grad_norm": grad_norm, + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + + return metrics + + def get_logprobs( + self, data: 
BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: """Get the logprobs of the model for a batch of data. Uses the configured logprob_batch_size to do microbatching. @@ -380,7 +395,11 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor. """ - logprob_batch_size = self.cfg["logprob_batch_size"] + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) all_log_probs = [] self.model.eval() @@ -495,7 +514,9 @@ def use_reference_model(self): val = to_local_if_dtensor(v) val.copy_(curr_buffers[k]) - def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + def get_reference_policy_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: """Get the logprobs from the reference policy for a batch of data. Returns: @@ -504,7 +525,7 @@ def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDic The logprob of input token i is specified at position i in the output logprobs tensor. 
""" with self.use_reference_model(): - reference_logprobs = self.get_logprobs(data) + reference_logprobs = self.get_logprobs(data, micro_batch_size) return_data = BatchedDataDict() return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu() diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index c06d738929..192d51ce88 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -15,7 +15,7 @@ import gc import warnings from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, Optional import ray @@ -225,99 +225,111 @@ def train( local_gbs = gbs // torch.distributed.get_world_size() dataset_size = data.get("input_ids").shape[0] - # Ensure model is in training mode - self.model.train() - - # Get data from batch and move to device - data.to("cuda") - - losses = [] - all_mb_metrics = [] - for gb_start in range(0, dataset_size, local_gbs): - self.optimizer.zero_grad() - mb_losses = [] - - # Calculate number of microbatches to process - # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size - # so its safe to not check for the case where the last data slice is smaller than mbs - num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs - - for mb in data.slice( - gb_start, gb_start + local_gbs - ).make_microbatch_iterator(mbs): - input_ids = mb.get("input_ids") + if eval_mode: + ctx = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + # Ensure model is in training mode + self.model.train() - input_lengths = mb.get("input_lengths") - batch_size, seq_len = input_ids.shape - attention_mask = torch.ones( - (batch_size, seq_len), dtype=torch.long, device=input_ids.device - ) - for i, length in enumerate(input_lengths): - # For right-padded sequence, set 1s at the 
beginning of the sequence - attention_mask[i, :length] = 1 + with ctx: + # Get data from batch and move to device + data.to("cuda") - with torch.autocast(device_type="cuda", dtype=self.dtype): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - use_cache=False, + losses = [] + all_mb_metrics = [] + for gb_start in range(0, dataset_size, local_gbs): + self.optimizer.zero_grad() + mb_losses = [] + + # Calculate number of microbatches to process + # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size + # so its safe to not check for the case where the last data slice is smaller than mbs + num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs + + for mb in data.slice( + gb_start, gb_start + local_gbs + ).make_microbatch_iterator(mbs): + input_ids = mb.get("input_ids") + + input_lengths = mb.get("input_lengths") + batch_size, seq_len = input_ids.shape + attention_mask = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) - # Get logprobs - if not hasattr(outputs, "logits"): - logits = self.model.lm_head(outputs.last_hidden_state) - else: - logits = outputs.logits - - loss, loss_metrics = loss_fn(logits, mb) - loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] - - # Backward pass - - # Loss is accumulated across microbatches, so we need to scale by the number of microbatches - loss = loss / num_microbatches + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + attention_mask[i, :length] = 1 + + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + # Get logprobs + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + + loss, loss_metrics = loss_fn(logits, mb) + num_valid_samples = loss_metrics["num_valid_samples"] 
+ loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + + # Backward pass + + # Loss is accumulated across microbatches, so we need to scale by the number of microbatches + loss = loss / num_microbatches + if not eval_mode: + ## NOTE: invalid samples should be multiplied + ## by zero in the loss function to prevent them + ## from affecting the gradient + loss.backward() + if num_valid_samples > 0: + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + # Clip gradients if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) - - # Clip gradients - if not eval_mode: - if self.cfg["max_grad_norm"] is not None: torch.nn.utils.clip_grad_norm_( self.model.parameters(), max_norm=self.cfg["max_grad_norm"] ) - # Update parameters - self.optimizer.step() - self.scheduler.step() - losses.append(torch.tensor(mb_losses).sum().item()) - - # Compute global loss across all ranks - with torch.no_grad(): - local_loss = torch.tensor(losses, device="cuda") - global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss) - global_loss = local_loss / torch.distributed.get_world_size() - - # Aggregate metrics across all microbatches - mb_metrics = defaultdict(list) - for m in all_mb_metrics: - for k, v in m.items(): - mb_metrics[k].append(v) - - metrics = { - "global_loss": global_loss.cpu(), - "local_loss": local_loss.cpu(), - "rank": torch.distributed.get_rank(), - "all_mb_metrics": dict(mb_metrics), - } - - return metrics - - def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + # Update parameters + self.optimizer.step() + self.scheduler.step() + losses.append(torch.tensor(mb_losses).sum().item()) + + # Compute global loss across all ranks + with torch.no_grad(): + local_loss = torch.tensor(losses, device="cuda") + global_loss = torch.zeros_like(local_loss) + torch.distributed.all_reduce(local_loss) + global_loss = local_loss / torch.distributed.get_world_size() + + # 
Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "local_loss": local_loss.cpu(), + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + + return metrics + + def get_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: """Get the logprobs of the model for a batch of data. - Uses the configured logprob_batch_size to do microbatching. + If no micro-batch size is provided, uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. @@ -327,7 +339,11 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor. """ - logprob_batch_size = self.cfg["logprob_batch_size"] + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) all_log_probs = [] self.model.eval() @@ -421,7 +437,9 @@ def use_reference_model(self): gc.collect() torch.cuda.empty_cache() - def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + def get_reference_policy_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: """Get the logprobs from the reference policy for a batch of data. Returns: @@ -430,7 +448,7 @@ def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDic The logprob of input token i is specified at position i in the output logprobs tensor. 
""" with self.use_reference_model(): - reference_logprobs = self.get_logprobs(data) + reference_logprobs = self.get_logprobs(data, micro_batch_size) return_data = BatchedDataDict() return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu() diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index e4fea94363..a068c27794 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -129,7 +129,7 @@ def get_logprobs( return logprobs def get_reference_policy_logprobs( - self, data: BatchedDataDict[GenerationDatumSpec] + self, data: BatchedDataDict[GenerationDatumSpec], micro_batch_size: int = None ) -> BatchedDataDict: """Get the logprobs of the reference policy for a data dict. @@ -137,7 +137,10 @@ def get_reference_policy_logprobs( """ sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) futures = self.worker_group.run_all_workers_multiple_data( - "get_reference_policy_logprobs", sharded_data, only_on="all_tied_workers" + "get_reference_policy_logprobs", + sharded_data, + common_kwargs={"micro_batch_size": micro_batch_size}, + only_on="all_tied_workers", ) logprobs = BatchedDataDict.from_batches( self.worker_group.get_all_worker_results(futures) @@ -153,11 +156,11 @@ def train( mbs: Optional[int] = None, ): """Train the policy on a batch of data with a given loss function.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] # Shard and replicate the batch shards = self.dp_size - sharded_data = data.shard_by_batch_size( - shards, batch_size=self.cfg["train_global_batch_size"] - ) + sharded_data = data.shard_by_batch_size(shards, batch_size=batch_size) # Train each shard in parallel futures = self.worker_group.run_all_workers_multiple_data( @@ -166,8 +169,8 @@ def train( common_kwargs={ "loss_fn": loss_fn, "eval_mode": eval_mode, - "gbs": gbs, - "mbs": mbs, + "gbs": 
batch_size,
+                "mbs": micro_batch_size,
             },
             only_on="all_tied_workers",
         )
diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh
new file mode 100755
index 0000000000..1431f17e61
--- /dev/null
+++ b/tests/functional/dpo.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs
+JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json
+RUN_LOG=$LOG_DIR/$(basename $0 .sh).log
+export RAY_DEDUP_LOGS=0
+export UV_CACHE_DIR=${UV_CACHE_DIR:-$PROJECT_ROOT/uv_cache}
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $LOG_DIR
+mkdir -p $LOG_DIR
+
+cd $PROJECT_ROOT
+python -u $PROJECT_ROOT/examples/run_dpo.py \
+    cluster.gpus_per_node=2 \
+    dpo.max_num_steps=3 \
+    dpo.val_batches=1 \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    checkpointing.enabled=false \
+    "$@" \
+    2>&1 | tee $RUN_LOG
+
+cd $SCRIPT_DIR
+python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+python check_metrics.py $JSON_METRICS \
+    'data["train/loss"]["2"] < 0.694'
+
diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py
new file mode 100644
index 0000000000..fa924a745d
--- /dev/null
+++ b/tests/unit/algorithms/test_dpo.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from unittest.mock import MagicMock, patch + +from nemo_reinforcer.algorithms.dpo import add_ref_logprobs_to_data + + +class MockPolicy: + def __init__(self, logprobs): + self.logprobs = logprobs + + def get_reference_policy_logprobs(self, batch, micro_batch_size): + return {"reference_logprobs": self.logprobs} + + +def test_add_logprobs_to_batch(): + """Test that add_ref_logprobs_to_data correctly adds reference policy logprobs to batches.""" + # Create mock data + batch_size = 2 + seq_len = 4 + vocab_size = 16 + + # Create a mock batch + mock_batch = { + "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)), + "attention_mask": torch.ones(batch_size, seq_len), + } + + # Create mock logprobs that will be returned by the policy + mock_logprobs = torch.randn(batch_size, seq_len) + + # Create a mock dataloader that yields our mock batch + mock_dataloader = MagicMock() + mock_dataloader.__iter__.return_value = iter([mock_batch]) + + # Create a mock policy that returns our mock logprobs + mock_policy = MockPolicy(mock_logprobs) + + # Create a mock master config + mock_master_config = {"policy": {"train_micro_batch_size": 1}} + + # Get the augmented batches + augmented_batches = list( + add_ref_logprobs_to_data(mock_dataloader, mock_policy, mock_master_config) + ) + + # Verify we got exactly one batch + assert len(augmented_batches) == 1 + augmented_batch = augmented_batches[0] + + # Verify the original batch data is preserved + assert torch.equal(augmented_batch["input_ids"], mock_batch["input_ids"]) + assert 
torch.equal(augmented_batch["attention_mask"], mock_batch["attention_mask"])
+
+    # Verify the reference policy logprobs were added correctly
+    assert "reference_policy_logprobs" in augmented_batch
+    assert augmented_batch["reference_policy_logprobs"].shape == (batch_size, seq_len)
+
+    # Verify the logprobs were rolled by -1 as expected
+    expected_logprobs = torch.roll(mock_logprobs, -1, dims=-1)
+    assert torch.equal(augmented_batch["reference_policy_logprobs"], expected_logprobs)
diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py
index 447bd20a54..abefb2b251 100644
--- a/tests/unit/algorithms/test_loss_functions.py
+++ b/tests/unit/algorithms/test_loss_functions.py
@@ -15,7 +15,11 @@
 import torch
 import numpy as np
 
-from nemo_reinforcer.algorithms.loss_functions import NLLLoss, ClippedPGLossFn
+from nemo_reinforcer.algorithms.loss_functions import (
+    NLLLoss,
+    ClippedPGLossFn,
+    DPOLossFn,
+)
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
 from nemo_reinforcer.algorithms.utils import (
     calculate_kl_penalty_joschu2020,
@@ -23,6 +27,22 @@
 )
 
 
+def setup_dpo_loss_test_data(vocab_size=16, batch_size=1):
+    seq_len = 4
+    data = {
+        "input_ids": torch.arange(vocab_size // 2)
+        .reshape(2 * batch_size, seq_len)
+        .to(torch.int64)
+        .to("cuda"),
+        "token_mask": torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]]).to("cuda"),
+        "sample_mask": torch.tensor([1, 1]).to("cuda"),
+        "reference_policy_logprobs": torch.zeros((2 * batch_size, seq_len)).to("cuda"),
+    }
+
+    next_token_logits = torch.zeros((2 * batch_size, seq_len, vocab_size)).to("cuda")
+    return data, next_token_logits
+
+
 def test_nll_loss():
     if not torch.cuda.is_available():
         pytest.skip("No GPU available")
@@ -79,6 +99,200 @@
     assert metrics_dict["total_tokens"] == 3
+
+def test_dpo_loss():
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    vocab_size = 16
+    batch_size = 1
+    num_unmasked_tokens = 2
+    data, 
next_token_logits = setup_dpo_loss_test_data( + vocab_size=vocab_size, + batch_size=batch_size, + ) + loss_fn = DPOLossFn( + cfg={ + "reference_policy_kl_penalty": 0.0, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.0, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + + loss, metrics_dict = loss_fn( + next_token_logits, + data, + ) + + ## chosen and rejected errors are the same, so difference between them is 0 + assert torch.isclose(loss.cpu(), -torch.nn.functional.logsigmoid(torch.tensor(0.0))) + + loss_fn_with_sft = DPOLossFn( + cfg={ + "reference_policy_kl_penalty": 0.0, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + + expected_sft_loss = ( + -( + torch.nn.functional.log_softmax(torch.tensor([[0.0] * vocab_size]), dim=-1)[ + :, 0 + ].sum() + ) + * num_unmasked_tokens + * batch_size + ) + expected_preference_loss = -torch.nn.functional.logsigmoid(torch.tensor(0.0)) + assert torch.isclose( + loss_fn_with_sft(next_token_logits, data)[0].cpu(), + 0.5 * expected_sft_loss + expected_preference_loss, + ) + + +def test_dpo_loss_varying_sequence_lengths(): + """Test DPO loss with varying sequence lengths and preference_average_log_probs=True.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + # Create DPO loss function with preference_average_log_probs=True + dpo_loss_fn_no_avg = DPOLossFn( + { + "reference_policy_kl_penalty": 0.1, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + dpo_loss_fn_avg = DPOLossFn( + { + "reference_policy_kl_penalty": 0.1, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": True, + "sft_average_log_probs": True, + } + ) + + # Create test data with varying sequence lengths + # Batch size 4 (2 pairs of chosen/rejected) + # Sequence lengths: [3, 5, 
4, 6]
+    batch_size = 4
+    max_seq_len = 6
+    vocab_size = 10
+
+    # Create input_ids with varying lengths
+    input_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long).to("cuda")
+    input_ids[0, :3] = torch.arange(3)  # length 3
+    input_ids[1, :5] = torch.arange(5)  # length 5
+    input_ids[2, :4] = torch.arange(4)  # length 4
+    input_ids[3, :6] = torch.arange(6)  # length 6
+
+    # Create token masks based on sequence lengths
+    token_mask = torch.zeros((batch_size, max_seq_len)).to("cuda")
+    token_mask[0, :3] = 1.0
+    token_mask[1, :5] = 1.0
+    token_mask[2, :4] = 1.0
+    token_mask[3, :6] = 1.0
+
+    # Create sample mask (all valid)
+    sample_mask = torch.ones(batch_size).to("cuda")
+
+    # Create reference policy logprobs
+    # (all zeros, so chosen and rejected start from identical reference logprobs)
+    reference_policy_logprobs = torch.zeros((batch_size, max_seq_len)).to("cuda")
+    # Create next token logits
+    next_token_logits = torch.zeros((batch_size, max_seq_len, vocab_size)).to("cuda")
+
+    # Create batched data dictionary
+    data = BatchedDataDict(
+        {
+            "input_ids": input_ids,
+            "reference_policy_logprobs": reference_policy_logprobs,
+            "token_mask": token_mask,
+            "sample_mask": sample_mask,
+        }
+    )
+
+    # Compute loss
+    loss, metrics = dpo_loss_fn_no_avg(next_token_logits, data)
+    loss_avg, metrics_avg = dpo_loss_fn_avg(next_token_logits, data)
+
+    num_unmasked_tokens = token_mask[:, 1:][::2].sum().item()
+    logprobs = torch.nn.functional.log_softmax(next_token_logits[:, 1:], dim=-1)
+    token_logprobs = logprobs.gather(
+        dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
+    ).squeeze(-1)
+    expected_per_token_sft_loss = -(token_logprobs[::2] * token_mask[:, 1:][::2])
+    ## sum across tokens in an example, average across examples
+    expected_sft_loss_no_avg = expected_per_token_sft_loss.sum(-1).mean()
+    ## average across tokens in an example, then average across examples
+    expected_sft_loss_avg = expected_per_token_sft_loss.sum() / num_unmasked_tokens
+
+    assert 
torch.isclose(torch.tensor(metrics["sft_loss"]), expected_sft_loss_no_avg) + assert torch.isclose(torch.tensor(metrics_avg["sft_loss"]), expected_sft_loss_avg) + + +def test_dpo_sft_matches_nll_loss(): + """Test that DPO SFT loss matches NLL loss when preference_loss_weight=0.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + # Setup test data + vocab_size = 8 + batch_size = 2 + dpo_data = { + "input_ids": torch.randint(0, vocab_size, (batch_size * 2, 5)) + .to(torch.int64) + .to("cuda"), + "token_mask": torch.tensor( + [[0, 0, 1, 1, 0], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0]] + ).to("cuda"), + "sample_mask": torch.tensor([1, 1, 1, 1]).to("cuda"), + "reference_policy_logprobs": torch.randn((4, 5)).to("cuda"), + } + + ## when computing the sft loss in DPO, we only use the chosen samples + sft_data = { + "input_ids": dpo_data["input_ids"][::2], + "token_mask": dpo_data["token_mask"][::2], + "sample_mask": dpo_data["sample_mask"][::2], + } + + # Create next token logits that will give non-zero loss + ## * 2 for chosen/rejected + next_token_logits = torch.randn((batch_size * 2, 5, vocab_size)).to("cuda") + + # Compute NLL loss + nll_loss_fn = NLLLoss() + nll_loss, nll_metrics = nll_loss_fn(next_token_logits[::2], sft_data) + + # Compute DPO loss with preference_loss_weight=0 + dpo_loss_fn = DPOLossFn( + cfg={ + "reference_policy_kl_penalty": 0.0, + "preference_loss_weight": 0.0, # Disable preference loss + "sft_loss_weight": 1.0, # Only use SFT loss + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + dpo_loss, dpo_metrics = dpo_loss_fn(next_token_logits, dpo_data) + + # Verify losses match + ## since DPO SFT loss just sums across tokens in a batch and then averages over the batch, + ## we need to re-normalize by multiplying by the batch size and dividing by the total number + ## of unmasked chosen tokens + torch.testing.assert_close( + dpo_loss / torch.sum(dpo_data["token_mask"][::2]) * batch_size, 
nll_loss + ) + + def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): """Sets up basic mock data structure. Tests should fill values.""" input_ids = torch.randint( # Input IDs only needed if original loss fn used diff --git a/tests/unit/data/hf_datasets/test_dpo_dataset.py b/tests/unit/data/hf_datasets/test_dpo_dataset.py new file mode 100644 index 0000000000..19d9d45ef6 --- /dev/null +++ b/tests/unit/data/hf_datasets/test_dpo_dataset.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import tempfile
+import json
+import pytest
+from unittest.mock import patch, MagicMock
+
+from nemo_reinforcer.data.hf_datasets.dpo import DPODataset
+
+
+@pytest.fixture
+def mock_dpo_data():
+    """Create temporary DPO dataset files with sample data."""
+    train_data = [
+        {
+            "prompt": "What is 2+2?",
+            "chosen_response": "The answer is 4.",
+            "rejected_response": "I don't know.",
+        },
+        {
+            "prompt": "What is the capital of France?",
+            "chosen_response": "The capital of France is Paris.",
+            "rejected_response": "The capital of France is London.",
+        },
+    ]
+
+    val_data = [
+        {
+            "prompt": "What is 3*3?",
+            "chosen_response": "The answer is 9.",
+            "rejected_response": "The answer is 6.",
+        }
+    ]
+
+    # Write the train/val samples to temporary JSON files; the paths are
+    # yielded to the test and the files are removed in the cleanup below.
+
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as train_file:
+        json.dump(train_data, train_file)
+        train_path = train_file.name
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as val_file:
+        json.dump(val_data, val_file)
+        val_path = val_file.name
+    yield train_path, val_path
+    # Cleanup
+    os.unlink(train_path)
+    os.unlink(val_path)
+
+
+def test_dpo_dataset_initialization(mock_dpo_data):
+    """Test that DPODataset initializes correctly with valid data files."""
+    train_path, val_path = mock_dpo_data
+
+    dataset = DPODataset(train_data_path=train_path, val_data_path=val_path)
+
+    # Verify dataset initialization
+    assert dataset.task_spec.task_name == "DPO"
+
+    # Verify formatted_ds structure
+    assert "train" in dataset.formatted_ds
+    assert "validation" in dataset.formatted_ds
+
+    assert len(dataset.formatted_ds["train"]) == 2
+    assert len(dataset.formatted_ds["validation"]) == 1
+
+
+def test_dpo_dataset_invalid_files():
+    """Test that DPODataset raises appropriate errors with invalid files."""
+    with 
pytest.raises(FileNotFoundError): + DPODataset(train_data_path="nonexistent.json", val_data_path="nonexistent.json") + + +def test_dpo_dataset_data_format(mock_dpo_data): + """Test that DPODataset correctly formats the data.""" + train_path, val_path = mock_dpo_data + dataset = DPODataset(train_data_path=train_path, val_data_path=val_path) + + # Verify data format + train_sample = dataset.formatted_ds["train"][0] + assert "prompt" in train_sample + assert "chosen_response" in train_sample + assert "rejected_response" in train_sample + + # Verify data content + assert train_sample["prompt"] == "What is 2+2?" + assert train_sample["chosen_response"] == "The answer is 4." + assert train_sample["rejected_response"] == "I don't know." diff --git a/tests/unit/data/hf_datasets/test_helpsteer.py b/tests/unit/data/hf_datasets/test_helpsteer.py new file mode 100644 index 0000000000..304fd5d2ad --- /dev/null +++ b/tests/unit/data/hf_datasets/test_helpsteer.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from nemo_reinforcer.data.hf_datasets.helpsteer3 import (
+    HelpSteer3Dataset,
+    format_helpsteer3,
+)
+
+
+def test_format_helpsteer3():
+    """Test the format_helpsteer3 function with different preference values."""
+    # Test case 1: response1 is preferred (overall_preference < 0)
+    data1 = {
+        "context": "What is 2+2?",
+        "response1": "The answer is 4.",
+        "response2": "I don't know.",
+        "overall_preference": -1,
+    }
+    result1 = format_helpsteer3(data1)
+    assert result1["prompt"] == "What is 2+2?"
+    assert result1["chosen_response"] == "The answer is 4."
+    assert result1["rejected_response"] == "I don't know."
+
+    # Test case 2: response2 is preferred (overall_preference > 0)
+    data2 = {
+        "context": "What is the capital of France?",
+        "response1": "The capital of France is London.",
+        "response2": "The capital of France is Paris.",
+        "overall_preference": 1,
+    }
+    result2 = format_helpsteer3(data2)
+    assert result2["prompt"] == "What is the capital of France?"
+    assert result2["chosen_response"] == "The capital of France is Paris."
+    assert result2["rejected_response"] == "The capital of France is London."
+
+    # Test case 3: no preference (overall_preference = 0)
+    data3 = {
+        "context": "What is the weather like?",
+        "response1": "It's sunny today.",
+        "response2": "The weather is sunny.",
+        "overall_preference": 0,
+    }
+    result3 = format_helpsteer3(data3)
+    assert result3["prompt"] == "What is the weather like?"
+    # When preference is 0, neither response is preferred, so
+    # response 1 is used for both chosen and rejected
+    assert result3["chosen_response"] == "It's sunny today."
+    assert result3["rejected_response"] == "It's sunny today."
+ + +def test_helpsteer3_dataset_initialization(): + """Test that HelpSteer3Dataset initializes correctly.""" + + dataset = HelpSteer3Dataset() + + # Verify dataset initialization + assert dataset.task_spec.task_name == "HelpSteer3" + + +def test_helpsteer3_dataset_data_format(): + """Test that HelpSteer3Dataset correctly formats the data.""" + + dataset = HelpSteer3Dataset() + + assert isinstance(dataset.formatted_ds, dict) + assert "train" in dataset.formatted_ds + assert "validation" in dataset.formatted_ds + + # Verify data format + sample = dataset.formatted_ds["train"][0] + assert "prompt" in sample + assert "chosen_response" in sample + assert "rejected_response" in sample diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py new file mode 100755 index 0000000000..7486e025c4 --- /dev/null +++ b/tests/unit/data/test_datasets.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+from unittest.mock import MagicMock
+
+from nemo_reinforcer.data.datasets import dpo_collate_fn
+from nemo_reinforcer.data.interfaces import DatumSpec
+from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+
+
+def test_dpo_collate_fn():
+    """Test that dpo_collate_fn correctly processes DPO training data."""
+    # Create mock tokenizer
+    mock_tokenizer = MagicMock()
+    mock_tokenizer.pad_token_id = 0
+
+    # Create test data with varying sequence lengths
+    data_batch = [
+        DatumSpec(
+            message_log_chosen=[
+                {
+                    "role": "user",
+                    "content": "Hello",
+                    "token_ids": torch.tensor([1, 2, 3]),
+                },
+                {
+                    "role": "assistant",
+                    "content": "Hi there",
+                    "token_ids": torch.tensor([4, 5, 6, 7]),
+                },
+            ],
+            message_log_rejected=[
+                {
+                    "role": "user",
+                    "content": "Hello",
+                    "token_ids": torch.tensor([1, 2, 3]),
+                },
+                {
+                    "role": "assistant",
+                    "content": "Bye",
+                    "token_ids": torch.tensor([8, 9]),
+                },
+            ],
+            length_chosen=7,
+            length_rejected=5,
+            loss_multiplier=1.0,
+            idx=0,
+            task_name="test_task",
+        ),
+        DatumSpec(
+            message_log_chosen=[
+                {
+                    "role": "user",
+                    "content": "How are you?",
+                    "token_ids": torch.tensor([10, 11, 12]),
+                },
+                {
+                    "role": "assistant",
+                    "content": "I'm good",
+                    "token_ids": torch.tensor([13, 14, 15]),
+                },
+            ],
+            message_log_rejected=[
+                {
+                    "role": "user",
+                    "content": "How are you?",
+                    "token_ids": torch.tensor([10, 11, 12]),
+                },
+                {
+                    "role": "assistant",
+                    "content": "Not great",
+                    "token_ids": torch.tensor([16, 17, 18, 19]),
+                },
+            ],
+            length_chosen=6,
+            length_rejected=7,
+            loss_multiplier=0.0,
+            idx=1,
+            task_name="test_task",
+        ),
+    ]
+
+    # Call dpo_collate_fn
+    train_data = dpo_collate_fn(
+        data_batch, mock_tokenizer, make_sequence_length_divisible_by=16
+    )
+
+    # Verify the output structure
+    assert isinstance(train_data, BatchedDataDict)
+    assert "input_ids" in train_data
+    assert "input_lengths" in train_data
+    assert "token_mask" in train_data
+    assert "sample_mask" in train_data
+ + # Verify batch size is doubled (chosen + rejected for each example) + assert train_data["input_ids"].shape[0] == 4 # 2 examples * 2 (chosen + rejected) + + # Verify input_ids shape and padding + max_length = 16 # max of all sequence lengths, padded to be divisible by 16 + assert train_data["input_ids"].shape == (4, max_length) + + # Verify input_lengths + expected_lengths = [7, 5, 6, 7] # chosen1, rejected1, chosen2, rejected2 + assert torch.equal(train_data["input_lengths"], torch.tensor(expected_lengths)) + + # Verify token_mask + assert train_data["token_mask"].shape == (4, max_length) + # First example chosen (length 7) + assert torch.all(train_data["token_mask"][0][0:3] == 0) + assert torch.all(train_data["token_mask"][0][3:7] == 1) + # First example rejected (length 5) + assert torch.all(train_data["token_mask"][1][0:3] == 0) + assert torch.all(train_data["token_mask"][1][3:5] == 1) + assert torch.all(train_data["token_mask"][1][5:] == 0) + + # Verify sample_mask + expected_sample_mask = [ + 1.0, + 1.0, + 0.0, + 0.0, + ] # loss_multiplier repeated for chosen/rejected + assert torch.equal(train_data["sample_mask"], torch.tensor(expected_sample_mask)) + + # Verify message content is preserved + # First example chosen + assert torch.equal(train_data["input_ids"][0][0:3], torch.tensor([1, 2, 3])) # user + assert torch.equal( + train_data["input_ids"][0][3:7], torch.tensor([4, 5, 6, 7]) + ) # assistant + # First example rejected + assert torch.equal(train_data["input_ids"][1][0:3], torch.tensor([1, 2, 3])) # user + assert torch.equal( + train_data["input_ids"][1][3:5], torch.tensor([8, 9]) + ) # assistant diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py index a966f1d710..2ff25beb1c 100644 --- a/tests/unit/data/test_llm_message_utils.py +++ b/tests/unit/data/test_llm_message_utils.py @@ -438,6 +438,22 @@ def test_add_loss_mask_to_chat_message_log( tokenized_chat_message_log[0][2]["token_loss_mask"], 
torch.tensor([1, 1]) ) + ## test only unmasking final message + add_loss_mask_to_message_log( + tokenized_chat_message_log, + only_unmask_final=True, + ) + assert torch.equal( + tokenized_chat_message_log[0][0]["token_loss_mask"], + torch.tensor([0, 0, 0, 0, 0, 0]), + ) + assert torch.equal( + tokenized_chat_message_log[0][1]["token_loss_mask"], torch.tensor([0, 0, 0]) + ) + assert torch.equal( + tokenized_chat_message_log[0][2]["token_loss_mask"], torch.tensor([1, 1]) + ) + def test_get_first_index_that_differs(): assert get_first_index_that_differs("hello", "hello") == 5 diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2773fd20f2..9972c1a1b6 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -22,7 +22,10 @@ def simple_loss( ) -> Tuple[torch.Tensor, Dict[str, Any]]: # Just return mean of logprobs as the loss for testing loss = next_token_logits.mean() - metrics = {"test_metric": loss.item() * 0.5} + metrics = { + "test_metric": loss.item() * 0.5, + "num_valid_samples": 1, + } return loss, metrics @@ -46,4 +49,7 @@ def nll_loss( token_loss_mask = data.get("token_loss_mask")[:, 1:].cuda() loss = -torch.sum(token_logprobs * token_loss_mask) - return loss, {"loss": loss.item()} + return loss, { + "loss": loss.item(), + "num_valid_samples": 1, + }