feat: DPO #180
Merged
Commits (62):
- 2e7bc59 initial dpo implementation (ashors1)
- f44a028 bug fixes (ashors1)
- 33e8535 small perf gains (ashors1)
- 1cb0784 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1)
- b430180 make dpo work with jsonl (ashors1)
- 4cc5a7f add validation and checkpointing (ashors1)
- c136aa7 revert handling of too-long lines (ashors1)
- e2307ae fix running validation with different batch size than training (ashors1)
- dbe01d0 changes for convergence testing (ashors1)
- 3ff1a7e no_grad and model.eval when in eval_mode (yfw)
- b2b3c66 clean up, add support for average_log_probs (ashors1)
- 8145e4c drop_last during training (ashors1)
- 78ae2eb small fixes for checkpointing and multi-epoch (ashors1)
- c039b7d add copyright (ashors1)
- 0adec9b clean up loss (ashors1)
- 63bec20 fixes for helpsteer dataset (ashors1)
- 444d5db add loss function unit test (ashors1)
- e9dae56 cleanup (ashors1)
- 8045ab9 add test for augment_dataloader (ashors1)
- db03bcc add dpo collate test (ashors1)
- 56044ab add dataset unit tests (ashors1)
- 0645b48 add DPO documentation (ashors1)
- bbcba50 fix example (ashors1)
- 5d58614 cleanup and add docstrings (ashors1)
- 19c0461 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1)
- d8b767a add another test, clean up (ashors1)
- 7d03127 update config (ashors1)
- 18faa44 rename test (ashors1)
- 7d97f30 fix test (ashors1)
- 7a1e190 address comments and update config (ashors1)
- 0f70438 add one more unit test (ashors1)
- 21fd438 minor readme update (ashors1)
- 398b17b add note on gbs to config (ashors1)
- 3f5276a fix functional test (ashors1)
- 6e870a1 small fixes (ashors1)
- 38e5264 decrease num steps (ashors1)
- e7deda8 fix DPO validation and correctly handle samples that are longer than … (ashors1)
- 7fbf847 fix comment (ashors1)
- fb76ff4 fix reduction over valid samples (ashors1)
- 7472f6b small bug fixes (ashors1)
- ace13fa log sum of valid samples rather than average (ashors1)
- b5c66ba address some comments (ashors1)
- fea1e22 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1)
- 6124416 small gbs and mbs fix, add copyright (ashors1)
- 6962df1 small fixes following rebase (ashors1)
- d38891a address comments, fix tests after rebase (ashors1)
- a7b2dec add hydra-style overrides (ashors1)
- 4ac99e3 sum valid samples across batch (ashors1)
- 82e8e98 decrease max steps and fix test (ashors1)
- 233a9ab address remaining comments (ashors1)
- 30e7140 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1)
- c22bc60 support dtensor with dpo (ashors1)
- 96c079c minor bug fixes (ashors1)
- 5a81c63 fix test loss fn (ashors1)
- a8b3efc update dpo docs (ashors1)
- 4c90c44 fix indentation (ashors1)
- 680dfbc address remaining comments (ashors1)
- b54b6c3 fix hyperlinks (ashors1)
- 4612a7a small readme fixes (ashors1)
- 458f9c6 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dpo (ashors1)
- fbc1f6d fix issues with merge (ashors1)
- 01c4f90 fix issue with rebase, add functional dpo test to ci (ashors1)
# Direct Preference Optimization in Reinforcer

[Direct Preference Optimization (DPO)](https://arxiv.org/pdf/2305.18290) is an RL-free alignment algorithm that operates on preference data. Given a prompt and a pair of chosen and rejected responses, DPO aims to increase the probability of the chosen response and decrease the probability of the rejected response relative to a frozen reference model. The actor is initialized from the reference model. For more details, refer to the [DPO paper](https://arxiv.org/pdf/2305.18290).
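The objective described above can be sketched numerically. The following is a minimal, framework-free illustration, not Reinforcer's actual implementation; `beta` plays the role of the KL penalty coefficient, and the log-probabilities are assumed to be summed over response tokens:

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO loss from response log-probabilities.

    The implicit reward of a response is beta * (log pi(y|x) - log pi_ref(y|x));
    the loss is -log sigmoid(reward_chosen - reward_rejected).
    """
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), written in a numerically stable form
    return math.log1p(math.exp(-margin))

# When policy and reference agree, the margin is zero and the loss is log(2).
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))  # → 0.6931471805599453
```

Training drives the margin positive, i.e. the policy comes to prefer the chosen response more strongly than the reference does, which pushes the loss below log(2).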
## Launch a DPO Run

The script [examples/run_dpo.py](../../examples/run_dpo.py) launches a DPO experiment. It can be run either locally or via Slurm; for details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md).

Be sure to launch the job using `uv`. The command to launch a DPO job is as follows:

```bash
uv run examples/run_dpo.py --config <PATH TO YAML CONFIG> <OVERRIDES>
```

If not specified, `config` defaults to [examples/configs/dpo.yaml](../../examples/configs/dpo.yaml).
## Configuration

Reinforcer allows users to configure DPO experiments using `yaml` config files. An example DPO configuration file can be found [here](../../examples/configs/dpo.yaml).

To override a value in the config, either update the value in the `yaml` file directly or pass the override via the command line. For example:

```bash
uv run examples/run_dpo.py \
    cluster.gpus_per_node=8 \
    dpo.sft_loss_weight=0.1 \
    dpo.preference_average_log_probs=True \
    logger.wandb.name="dpo-dev-8-gpu"
```

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll also need to run `huggingface-cli login` for Llama models.
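A typical environment setup might look like the following; every path and value here is a placeholder to substitute with your own:

```shell
# Placeholder cache locations; point these at storage with enough space.
export HF_HOME=/scratch/hf_home
export HF_DATASETS_CACHE=/scratch/hf_datasets_cache
export WANDB_API_KEY="your-wandb-api-key"

# Run once, interactively, to access gated Llama checkpoints:
# huggingface-cli login
```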
## Datasets

Each class representing a Reinforcer DPO dataset is expected to have the following attributes:
1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below.
2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.

DPO datasets are expected to follow a specific format with three key fields:
- `prompt`: The input prompt/context
- `chosen_response`: The preferred (winning) response
- `rejected_response`: The non-preferred (losing) response

[data/hf_datasets/helpsteer3.py](../../nemo_reinforcer/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO:
```python
def format_helpsteer3(data):
    response_1 = data["response1"]
    response_2 = data["response2"]
    overall_preference = data["overall_preference"]

    if overall_preference < 0:
        chosen = response_1
        rejected = response_2
    elif overall_preference == 0:
        chosen = response_1
        rejected = response_1
    else:
        chosen = response_2
        rejected = response_1

    return {
        "prompt": data["context"],
        "chosen_response": chosen,
        "rejected_response": rejected,
    }
```
We also provide a [DPODataset](../../nemo_reinforcer/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datasets. This class assumes the train and validation datasets have been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys.
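Concretely, each line of such a jsonl file is a single JSON object with exactly those three keys. The snippet below, using made-up data, shows the expected shape:

```python
import json

# One preference example in the shape the jsonl files should contain
# (the text itself is illustrative only).
example = {
    "prompt": "Write a one-line greeting.",
    "chosen_response": "Hello there, friend!",
    "rejected_response": "greeting",
}

line = json.dumps(example)      # one line of the jsonl file
parsed = json.loads(line)       # what gets read back per line
print(sorted(parsed.keys()))    # → ['chosen_response', 'prompt', 'rejected_response']
```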
## Adding Custom DPO Datasets

Adding a new DPO dataset is straightforward. Your custom dataset class should:
1. Implement the required format conversion in the constructor
2. Set up the appropriate `task_spec`

Here's a minimal example which simply re-keys an existing jsonl dataset:
```{testcode}
from datasets import load_dataset

from nemo_reinforcer.data.interfaces import TaskDataSpec
from docs.helpers import make_dpo_dataset


class CustomDPODataset:
    def preprocess_dataset(
        self,
        data,
        prompt_key: str = "context",
        chosen_key: str = "chosen",
        rejected_key: str = "rejected",
    ):
        return {
            "prompt": data[prompt_key],
            "chosen_response": data[chosen_key],
            "rejected_response": data[rejected_key],
        }

    def __init__(
        self,
        train_data_path: str,
        val_data_path: str,
        prompt_key: str,
        chosen_key: str,
        rejected_key: str,
    ):
        # Load and format your dataset
        fn_kwargs = {
            "prompt_key": prompt_key,
            "chosen_key": chosen_key,
            "rejected_key": rejected_key,
        }
        formatted_ds = {
            "train": load_dataset("json", data_files=train_data_path, split="train").map(
                self.preprocess_dataset,
                fn_kwargs=fn_kwargs,
            ),
            "validation": load_dataset("json", data_files=val_data_path, split="train").map(
                self.preprocess_dataset,
                fn_kwargs=fn_kwargs,
            ),
        }

        # Initialize task spec with dataset name
        self.task_spec = TaskDataSpec(
            task_name="custom_dpo",
        )
        self.formatted_ds = formatted_ds


# Create temporary files using helper function
train_file, val_file = make_dpo_dataset()

# Initialize dataset
dataset = CustomDPODataset(
    train_data_path=train_file.name,
    val_data_path=val_file.name,
    prompt_key="context",
    chosen_key="chosen",
    rejected_key="rejected",
)

# Test dataset properties
print(f"Task name: {dataset.task_spec.task_name}")
print(f"Train examples: {len(dataset.formatted_ds['train'])}")
print(f"Validation examples: {len(dataset.formatted_ds['validation'])}")
print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}")
print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}")
print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}")
```
```{testoutput}
Task name: custom_dpo
Train examples: 2
Validation examples: 2
First train example prompt: What is 2+2?
First train example chosen response: 4
First train example rejected response: 5
```
## DPO-Specific Parameters

The DPO implementation in Reinforcer supports several key parameters that can be adjusted:

- `dpo.reference_policy_kl_penalty`: Controls the strength of the KL penalty term
- `dpo.preference_loss_weight`: Weight for the preference loss
- `dpo.sft_loss_weight`: Weight for the auxiliary SFT loss
- `dpo.preference_average_log_probs`: Whether to average log probabilities over tokens in the preference loss term
- `dpo.sft_average_log_probs`: Whether to average log probabilities over tokens in the SFT loss term

These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.
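To make the interaction between these knobs concrete, here is a rough sketch — a hypothetical helper, not Reinforcer's actual loss code — of how they could combine per-token log-probabilities into a single scalar loss:

```python
import math

def combined_dpo_loss(chosen_token_logps, rejected_token_logps,
                      ref_chosen_token_logps, ref_rejected_token_logps,
                      reference_policy_kl_penalty=0.05,
                      preference_loss_weight=1.0,
                      sft_loss_weight=0.0,
                      preference_average_log_probs=False,
                      sft_average_log_probs=False):
    """Sketch of how the DPO config knobs combine into one scalar loss."""
    def agg(token_logps, average):
        # Sum token log-probs, or average them when the corresponding flag is set.
        total = sum(token_logps)
        return total / len(token_logps) if average else total

    # Preference term: -log sigmoid(beta * (margin vs. the reference model))
    margin = (agg(chosen_token_logps, preference_average_log_probs)
              - agg(ref_chosen_token_logps, preference_average_log_probs)
              - agg(rejected_token_logps, preference_average_log_probs)
              + agg(ref_rejected_token_logps, preference_average_log_probs))
    preference_loss = math.log1p(math.exp(-reference_policy_kl_penalty * margin))

    # Auxiliary SFT term: negative log-likelihood of the chosen response
    sft_loss = -agg(chosen_token_logps, sft_average_log_probs)

    return (preference_loss_weight * preference_loss
            + sft_loss_weight * sft_loss)
```

Averaging rather than summing log-probabilities keeps long responses from dominating the margin (or the SFT term) simply by having more tokens, which is one reason to try the `*_average_log_probs` flags.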
```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import tempfile


def make_dpo_dataset():
    train_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
    val_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)

    # Write train data
    train_data = [
        {"context": "What is 2+2?", "chosen": "4", "rejected": "5"},
        {"context": "What is 3*3?", "chosen": "9", "rejected": "6"},
    ]
    for item in train_data:
        train_file.write(json.dumps(item) + "\n")
    train_file.flush()

    # Write validation data
    val_data = [
        {"context": "What is 4+4?", "chosen": "8", "rejected": "7"},
        {"context": "What is 5*5?", "chosen": "25", "rejected": "20"},
    ]
    for item in val_data:
        val_file.write(json.dumps(item) + "\n")
    val_file.flush()

    return train_file, val_file
```