Merged

62 commits
- `2e7bc59` initial dpo implementation (ashors1, Apr 4, 2025)
- `f44a028` bug fixes (ashors1, Apr 4, 2025)
- `33e8535` small perf gains (ashors1, Apr 7, 2025)
- `1cb0784` Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1, Apr 8, 2025)
- `b430180` make dpo work with jsonl (ashors1, Apr 9, 2025)
- `4cc5a7f` add validation and checkpointing (ashors1, Apr 9, 2025)
- `c136aa7` revert handling of too-long lines (ashors1, Apr 11, 2025)
- `e2307ae` fix running validation with different batch size than training (ashors1, Apr 11, 2025)
- `dbe01d0` changes for convergence testing (ashors1, Apr 12, 2025)
- `3ff1a7e` no_grad and model.eval when in eval_mode (yfw, Apr 13, 2025)
- `b2b3c66` clean up, add support for average_log_probs (ashors1, Apr 14, 2025)
- `8145e4c` drop_last during training (ashors1, Apr 14, 2025)
- `78ae2eb` small fixes for checkpointing and multi-epoch (ashors1, Apr 14, 2025)
- `c039b7d` add copyright (ashors1, Apr 14, 2025)
- `0adec9b` clean up loss (ashors1, Apr 14, 2025)
- `63bec20` fixes for helpsteer dataset (ashors1, Apr 14, 2025)
- `444d5db` add loss function unit test (ashors1, Apr 14, 2025)
- `e9dae56` cleanup (ashors1, Apr 14, 2025)
- `8045ab9` add test for augment_dataloader (ashors1, Apr 14, 2025)
- `db03bcc` add dpo collate test (ashors1, Apr 14, 2025)
- `56044ab` add dataset unit tests (ashors1, Apr 14, 2025)
- `0645b48` add DPO documentation (ashors1, Apr 14, 2025)
- `bbcba50` fix example (ashors1, Apr 14, 2025)
- `5d58614` cleanup and add docstrings (ashors1, Apr 14, 2025)
- `19c0461` Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1, Apr 14, 2025)
- `d8b767a` add another test, clean up (ashors1, Apr 15, 2025)
- `7d03127` update config (ashors1, Apr 15, 2025)
- `18faa44` rename test (ashors1, Apr 15, 2025)
- `7d97f30` fix test (ashors1, Apr 15, 2025)
- `7a1e190` address comments and update config (ashors1, Apr 15, 2025)
- `0f70438` add one more unit test (ashors1, Apr 15, 2025)
- `21fd438` minor readme update (ashors1, Apr 15, 2025)
- `398b17b` add note on gbs to config (ashors1, Apr 15, 2025)
- `3f5276a` fix functional test (ashors1, Apr 15, 2025)
- `6e870a1` small fixes (ashors1, Apr 15, 2025)
- `38e5264` decrease num steps (ashors1, Apr 16, 2025)
- `e7deda8` fix DPO validation and correctly handle samples that are longer than … (ashors1, Apr 16, 2025)
- `7fbf847` fix comment (ashors1, Apr 16, 2025)
- `fb76ff4` fix reduction over valid samples (ashors1, Apr 16, 2025)
- `7472f6b` small bug fixes (ashors1, Apr 17, 2025)
- `ace13fa` log sum of valid samples rather than average (ashors1, Apr 17, 2025)
- `b5c66ba` address some comments (ashors1, Apr 17, 2025)
- `fea1e22` Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1, Apr 17, 2025)
- `6124416` small gbs and mbs fix, add copyright (ashors1, Apr 17, 2025)
- `6962df1` small fixes following rebase (ashors1, Apr 17, 2025)
- `d38891a` address comments, fix tests after rebase (ashors1, Apr 17, 2025)
- `a7b2dec` add hydra-style overrides (ashors1, Apr 17, 2025)
- `4ac99e3` sum valid samples across batch (ashors1, Apr 17, 2025)
- `82e8e98` decrease max steps and fix test (ashors1, Apr 17, 2025)
- `233a9ab` address remaining comments (ashors1, Apr 17, 2025)
- `30e7140` Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1, Apr 17, 2025)
- `c22bc60` support dtensor with dpo (ashors1, Apr 17, 2025)
- `96c079c` minor bug fixes (ashors1, Apr 17, 2025)
- `5a81c63` fix test loss fn (ashors1, Apr 18, 2025)
- `a8b3efc` update dpo docs (ashors1, Apr 18, 2025)
- `4c90c44` fix indentation (ashors1, Apr 18, 2025)
- `680dfbc` address remaining comments (ashors1, Apr 18, 2025)
- `b54b6c3` fix hyperlinks (ashors1, Apr 18, 2025)
- `4612a7a` small readme fixes (ashors1, Apr 18, 2025)
- `458f9c6` Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1, Apr 22, 2025)
- `fbc1f6d` fix issues with merge (ashors1, Apr 22, 2025)
- `01c4f90` fix issue with rebase, add functional dpo test to ci (ashors1, Apr 22, 2025)
1 change: 1 addition & 0 deletions .github/workflows/cicd-main.yml
@@ -170,6 +170,7 @@ jobs:
if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then
uv run --no-sync bash ./tests/functional/sft.sh
uv run --no-sync bash ./tests/functional/grpo.sh
uv run --no-sync bash ./tests/functional/dpo.sh
else
echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }}
fi
106 changes: 84 additions & 22 deletions README.md
@@ -5,12 +5,15 @@
- [Features](#features)
- [Prerequisites](#prerequisites)
- [Quick start](#quick-start)
- [GRPO](#grpo)
  - [Single Node](#single-node)
  - [Multi-node](#multi-node)
- [SFT](#sft)
  - [Single Node](#single-node-1)
  - [Multi-node](#multi-node-1)
- [DPO](#dpo)
  - [Single Node](#single-node-2)
  - [Multi-node](#multi-node-2)
- [Cluster Start](#cluster-start)

**Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters.
@@ -33,10 +36,10 @@ What you can expect:
- ✅ **Environment Support** - Support for multi-environment training.
- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning)
- ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state)
- ✅ **DPO Algorithm** - Direct Preference Optimization for alignment
- 🔜 **Larger Model Support** - Native PyTorch support for models up to 70B parameters
- 🔜 **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training
- 🔜 **Environment Isolation** - Dependency isolation between components

## Prerequisites

@@ -59,6 +62,61 @@ pip install uv

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models.

### GRPO

We have a reference GRPO experiment config set up trained for math benchmarks using the [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset.

#### Single Node

To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`:

```sh
# Run the GRPO math example using a 1B parameter model
uv run python examples/run_grpo_math.py
```

By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 GPUs:

```sh
# Run the GRPO math example using a 1B parameter model using 8 GPUs
uv run python examples/run_grpo_math.py \
cluster.gpus_per_node=8
```

You can override any of the parameters listed in the yaml configuration file. For example,

```sh
uv run python examples/run_grpo_math.py \
policy.model_name="Qwen/Qwen2-1.5B" \
checkpointing.checkpoint_dir="results/qwen1_5b_math" \
logger.wandb_enabled=True \
logger.wandb.name="grpo-qwen1_5b_math" \
    logger.num_val_samples_to_print=10
```

#### Multi-node

```sh
# Run from the root of NeMo-Reinforcer repo
NUM_ACTOR_NODES=2
# Add a timestamp to make each job name unique
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# grpo_math_8b uses Llama-3.1-8B-Instruct model
COMMAND="uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \
UV_CACHE_DIR=YOUR_UV_CACHE_DIR \
CONTAINER=YOUR_CONTAINER \
MOUNTS="$PWD:$PWD" \
sbatch \
--nodes=${NUM_ACTOR_NODES} \
--account=YOUR_ACCOUNT \
--job-name=YOUR_JOBNAME \
--partition=YOUR_PARTITION \
--time=4:0:0 \
--gres=gpu:8 \
ray.sub
```

### SFT

We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/).
@@ -87,15 +145,12 @@ Refer to `examples/configs/sft.yaml` for a full list of parameters that can be overridden.

#### Multi-node

For distributed training across multiple nodes:

```sh
# Run from the root of NeMo-Reinforcer repo
NUM_ACTOR_NODES=2
# Add a timestamp to make each job name unique
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# SFT experiment uses Llama-3.1-8B model
COMMAND="uv run ./examples/run_sft.py --config examples/configs/sft.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/sft_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='sft-llama8b'" \
CONTAINER=YOUR_CONTAINER \
MOUNTS="$PWD:$PWD" \
@@ -109,48 +164,55 @@ sbatch \
ray.sub
```

### DPO

We provide a sample DPO experiment that uses the [HelpSteer3 dataset](https://huggingface.co/datasets/nvidia/HelpSteer3) for preference-based training.

#### Single Node

The default DPO experiment is configured to run on a single GPU. To launch the experiment:

```sh
uv run python examples/run_dpo.py
```

This trains `Llama-3.2-1B-Instruct` on one GPU.

If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama3.1 Instruct model:

```sh
uv run python examples/run_dpo.py \
policy.model_name="meta-llama/Llama-3.1-8B-Instruct" \
policy.train_global_batch_size=256 \
cluster.gpus_per_node=8
```

Any of the DPO parameters can be customized from the command line. For example:

```sh
uv run python examples/run_dpo.py \
dpo.sft_loss_weight=0.1 \
dpo.preference_average_log_probs=True \
checkpointing.checkpoint_dir="results/llama_dpo_sft" \
logger.wandb_enabled=True \
logger.wandb.name="llama-dpo-sft"
```

Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md).

#### Multi-node

For distributed DPO training across multiple nodes, modify the following script for your use case:

```sh
# Run from the root of NeMo-Reinforcer repo
# Number of nodes to use for your job
NUM_ACTOR_NODES=2
# Add a timestamp to make each job name unique
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

COMMAND="uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 dpo.val_global_batch_size=32 checkpointing.checkpoint_dir='results/dpo_llama1b_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama1b'" \
RAY_DEDUP_LOGS=0 \
CONTAINER=YOUR_CONTAINER \
MOUNTS="$PWD:$PWD" \
sbatch \
169 changes: 169 additions & 0 deletions docs/guides/dpo.md
@@ -0,0 +1,169 @@
# Direct Preference Optimization in Reinforcer

[Direct Preference Optimization (DPO)](https://arxiv.org/pdf/2305.18290) is an RL-free alignment algorithm that operates on preference data. Given a prompt and a pair of chosen and rejected responses, DPO aims
to increase the probability of the chosen response and decrease the probability of the rejected response relative to a frozen reference model. The actor is initialized using the reference model. For more details, refer to the
[DPO paper](https://arxiv.org/pdf/2305.18290).
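
Concretely, the objective is a logistic loss on the margin between the policy's and the reference model's log-probability ratios for the two responses. The following is a minimal per-example sketch (illustrative only, not Reinforcer's implementation; the `beta` coefficient here plays the role of the `dpo.reference_policy_kl_penalty` config value described later in this guide):

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO loss: -log sigmoid(beta * margin), where the margin
    measures how much more the policy prefers the chosen response over the
    rejected one, relative to the frozen reference model."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(x)) == softplus(-x), computed in a numerically stable way
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
```

When the policy still matches the reference, the margin is zero and the loss is `log(2)`; it shrinks as the policy assigns relatively more probability to the chosen response.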

## Launch a DPO Run

The script [examples/run_dpo.py](../../examples/run_dpo.py) can be used to launch a DPO experiment. This script can either be launched locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md).

Be sure to launch the job using `uv`. The command to launch a DPO job is as follows:
```bash
uv run examples/run_dpo.py --config <PATH TO YAML CONFIG> <OVERRIDES>
```
If not specified, `config` will default to [examples/configs/dpo.yaml](../../examples/configs/dpo.yaml).

## Configuration

Reinforcer allows users to configure DPO experiments using `yaml` config files. An example DPO configuration file can be found [here](../../examples/configs/dpo.yaml).

To override a value in the config, either update the value in the `yaml` file directly, or pass the override via the command line. For example:

```bash
uv run examples/run_dpo.py \
cluster.gpus_per_node=8 \
dpo.sft_loss_weight=0.1 \
dpo.preference_average_log_probs=True \
logger.wandb.name="dpo-dev-8-gpu"
```

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models.

## Datasets

Each class representing a Reinforcer DPO dataset is expected to have the following attributes:
1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below.
2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.

DPO datasets are expected to follow a specific format with three key fields:
- `prompt`: The input prompt/context
- `chosen_response`: The preferred/winning response
- `rejected_response`: The non-preferred/losing response

[data/hf_datasets/helpsteer3.py](../../nemo_reinforcer/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO:

```python
def format_helpsteer3(data):
    response_1 = data["response1"]
    response_2 = data["response2"]
    overall_preference = data["overall_preference"]

    if overall_preference < 0:
        chosen = response_1
        rejected = response_2
    elif overall_preference == 0:
        # a tie yields identical chosen and rejected responses
        chosen = response_1
        rejected = response_1
    else:
        chosen = response_2
        rejected = response_1

    return {
        "prompt": data["context"],
        "chosen_response": chosen,
        "rejected_response": rejected,
    }
```

We also provide a [DPODataset](../../nemo_reinforcer/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datasets. This class assumes train and validation datasets have been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys.
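
For example, a single record in one of these jsonl files could be constructed as follows (the field values are illustrative):

```python
import json

# A hypothetical record in the schema DPODataset expects
record = {
    "prompt": "What is the capital of France?",
    "chosen_response": "The capital of France is Paris.",
    "rejected_response": "The capital of France is Lyon.",
}

# Each line of the .jsonl file is one such JSON object
print(json.dumps(record))
```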

## Adding Custom DPO Datasets

Adding a new DPO dataset is straightforward. Your custom dataset class should:
1. Implement the required format conversion in the constructor
2. Set up the appropriate `task_spec`

Here's a minimal example which simply re-keys an existing jsonl dataset:

```{testcode}
from datasets import load_dataset
from nemo_reinforcer.data.interfaces import TaskDataSpec
from docs.helpers import make_dpo_dataset

class CustomDPODataset:
    def preprocess_dataset(
        self,
        data,
        prompt_key: str = "context",
        chosen_key: str = "chosen",
        rejected_key: str = "rejected",
    ):
        return {
            "prompt": data[prompt_key],
            "chosen_response": data[chosen_key],
            "rejected_response": data[rejected_key],
        }

    def __init__(
        self,
        train_data_path: str,
        val_data_path: str,
        prompt_key: str,
        chosen_key: str,
        rejected_key: str,
    ):
        # Load and format your dataset
        fn_kwargs = {
            "prompt_key": prompt_key,
            "chosen_key": chosen_key,
            "rejected_key": rejected_key,
        }
        formatted_ds = {
            "train": load_dataset("json", data_files=train_data_path, split="train").map(
                self.preprocess_dataset,
                fn_kwargs=fn_kwargs,
            ),
            "validation": load_dataset("json", data_files=val_data_path, split="train").map(
                self.preprocess_dataset,
                fn_kwargs=fn_kwargs,
            ),
        }

        # Initialize task spec with dataset name
        self.task_spec = TaskDataSpec(
            task_name="custom_dpo",
        )
        self.formatted_ds = formatted_ds

# Create temporary files using helper function
train_file, val_file = make_dpo_dataset()

# Initialize dataset
dataset = CustomDPODataset(
    train_data_path=train_file.name,
    val_data_path=val_file.name,
    prompt_key="context",
    chosen_key="chosen",
    rejected_key="rejected",
)

# Test dataset properties
print(f"Task name: {dataset.task_spec.task_name}")
print(f"Train examples: {len(dataset.formatted_ds['train'])}")
print(f"Validation examples: {len(dataset.formatted_ds['validation'])}")
print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}")
print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}")
print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}")
```

```{testoutput}
Task name: custom_dpo
Train examples: 2
Validation examples: 2
First train example prompt: What is 2+2?
First train example chosen response: 4
First train example rejected response: 5
```

## DPO-Specific Parameters

The DPO implementation in Reinforcer supports several key parameters that can be adjusted:

- `dpo.reference_policy_kl_penalty`: Controls the strength of the KL penalty term
- `dpo.preference_loss_weight`: Weight for the preference loss
- `dpo.sft_loss_weight`: Weight for the auxiliary SFT loss
- `dpo.preference_average_log_probs`: Whether to average log probabilities over tokens in the preference loss term
- `dpo.sft_average_log_probs`: Whether to average log probabilities over tokens in the SFT loss term

These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.
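
As a rough sketch of how these knobs fit together (illustrative pseudocode only, not Reinforcer's actual implementation; the function and argument names are hypothetical):

```python
def sequence_logprob(token_logps, average=False):
    """Aggregate per-token log-probs into a sequence score.

    average=True mirrors the *_average_log_probs options: log-probs are
    averaged over tokens instead of summed, removing the bias toward
    shorter sequences that a plain sum introduces."""
    total = sum(token_logps)
    return total / len(token_logps) if average else total

def total_loss(preference_loss, sft_loss,
               preference_loss_weight=1.0, sft_loss_weight=0.0):
    # Weighted combination mirroring dpo.preference_loss_weight
    # and dpo.sft_loss_weight
    return preference_loss_weight * preference_loss + sft_loss_weight * sft_loss

print(sequence_logprob([-1.0, -2.0, -3.0]))                # -6.0
print(sequence_logprob([-1.0, -2.0, -3.0], average=True))  # -2.0
```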
41 changes: 41 additions & 0 deletions docs/helpers.py
@@ -0,0 +1,41 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import json


def make_dpo_dataset():
    train_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
    val_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)

    # Write train data
    train_data = [
        {"context": "What is 2+2?", "chosen": "4", "rejected": "5"},
        {"context": "What is 3*3?", "chosen": "9", "rejected": "6"},
    ]
    for item in train_data:
        train_file.write(json.dumps(item) + "\n")
    train_file.flush()

    # Write validation data
    val_data = [
        {"context": "What is 4+4?", "chosen": "8", "rejected": "7"},
        {"context": "What is 5*5?", "chosen": "25", "rejected": "20"},
    ]
    for item in val_data:
        val_file.write(json.dumps(item) + "\n")
    val_file.flush()

    return train_file, val_file