Merged
4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -58,8 +58,6 @@
- sections: # Sorted alphabetically
- local: dpo_trainer
title: DPO
- - local: online_dpo_trainer
- title: Online DPO
- local: grpo_trainer
title: GRPO
- local: kto_trainer
@@ -111,6 +109,8 @@
title: MiniLLM
- local: nash_md_trainer
title: Nash-MD
+ - local: online_dpo_trainer
+ title: Online DPO
- local: orpo_trainer
title: ORPO
- local: papo_trainer
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -390,14 +390,14 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
- | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
+ | [`experimental.online_dpo.OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
| [`experimental.prm.PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
4 changes: 2 additions & 2 deletions docs/source/example_overview.md
@@ -53,8 +53,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. |
- | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
- | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a a Vision Language Model. |
+ | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. |
+ | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a a Vision Language Model. |
| [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM |
| [`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM |
| [`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. |
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -24,8 +24,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL

- [`GRPOTrainer`] ⚡️
- [`RLOOTrainer`] ⚡️
- - [`OnlineDPOTrainer`] ⚡️
- [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️
+ - [`experimental.online_dpo.OnlineDPOTrainer`] 🧪 ⚡️
- [`experimental.ppo.PPOTrainer`] 🧪
- [`experimental.xpo.XPOTrainer`] 🧪 ⚡️

12 changes: 6 additions & 6 deletions docs/source/online_dpo_trainer.md
@@ -28,8 +28,8 @@ Below is the script to train the model:
```python
# train_online_dpo.py
from datasets import load_dataset
- from trl import OnlineDPOConfig, OnlineDPOTrainer
from trl.experimental.judges import PairRMJudge
+ from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
@@ -66,7 +66,7 @@ The best programming language depends on your specific needs and priorities. Som

## Expected dataset type

- Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+ Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`experimental.online_dpo.OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

## Usage tips

@@ -93,7 +93,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht

### Encourage EOS token generation

- When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
+ When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.online_dpo.OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.online_dpo.OnlineDPOConfig`]:

```python
training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
@@ -147,7 +147,7 @@ While training and evaluating, we record the following reward metrics. Here is a
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the rejected completions.
* `val/contain_eos_token`: The fraction of completions which contain an EOS token.
- * `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`].
+ * `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.online_dpo.OnlineDPOConfig`].

## Benchmark experiments

@@ -261,11 +261,11 @@ The online DPO checkpoint gets increasingly more win rate as we scale up the mod

## OnlineDPOTrainer

- [[autodoc]] OnlineDPOTrainer
+ [[autodoc]] experimental.online_dpo.OnlineDPOTrainer
- train
- save_model
- push_to_hub

## OnlineDPOConfig

- [[autodoc]] OnlineDPOConfig
+ [[autodoc]] experimental.online_dpo.OnlineDPOConfig
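
Reader's note on the "Encourage EOS token generation" hunk above: the docs say only that `missing_eos_penalty` penalizes completions that reach `max_new_tokens` without emitting an EOS token. The sketch below illustrates that idea in plain Python, assuming the penalty is simply subtracted from the reward score of each completion that lacks an EOS token. The function name `apply_missing_eos_penalty` and the example token IDs are hypothetical and are not taken from the TRL source.

```python
def apply_missing_eos_penalty(scores, completions, eos_token_id, penalty):
    """Return new scores, lowering those whose completion has no EOS token.

    Hypothetical helper for illustration only; it is not part of TRL.
    """
    penalized = []
    for score, token_ids in zip(scores, completions):
        if eos_token_id not in token_ids:
            # Completion hit the length limit without terminating: penalize it.
            score = score - penalty
        penalized.append(score)
    return penalized


scores = [0.8, 0.5]                        # made-up reward-model scores
completions = [[11, 42, 2], [11, 42, 7]]   # made-up token IDs; 2 plays the role of EOS
print(apply_missing_eos_penalty(scores, completions, eos_token_id=2, penalty=1.0))
# [0.8, -0.5]
```

With a penalty of `1.0`, the second completion ranks below the first even if its raw score had been higher, which is the pressure toward terminating within the length limit that the docs describe.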
2 changes: 1 addition & 1 deletion docs/source/reducing_memory_usage.md
@@ -265,7 +265,7 @@ training_args = GRPOConfig(..., ds3_gather_for_generation=False)
<hfoption id="Online DPO">

```python
- from trl import OnlineDPOConfig
+ from trl.experimental.online_dpo import OnlineDPOConfig

training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```
2 changes: 1 addition & 1 deletion docs/source/speeding_up_training.md
@@ -20,7 +20,7 @@ pip install trl[vllm]
Then, enable it by passing `use_vllm=True` in the training arguments.

```python
- from trl import OnlineDPOConfig
+ from trl.experimental.online_dpo import OnlineDPOConfig

training_args = OnlineDPOConfig(..., use_vllm=True)
```
8 changes: 4 additions & 4 deletions docs/source/vllm_integration.md
@@ -9,9 +9,9 @@ This document will guide you through the process of using vLLM with TRL for fast
> The following trainers currently support generation with vLLM:
>
> - [`GRPOTrainer`]
- > - [`OnlineDPOTrainer`]
> - [`RLOOTrainer`]
> - [`experimental.nash_md.NashMDTrainer`]
+ > - [`experimental.online_dpo.OnlineDPOTrainer`]
> - [`experimental.xpo.XPOTrainer`]

## 🚀 How can I use vLLM with TRL to speed up training?
@@ -65,7 +65,7 @@ trainer.train()

```python
from datasets import load_dataset
- from trl import OnlineDPOTrainer, OnlineDPOConfig
+ from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
@@ -316,7 +316,7 @@ training_args = GRPOConfig(
<hfoption id="OnlineDPO">

```python
- from trl import OnlineDPOConfig
+ from trl.experimental.online_dpo import OnlineDPOConfig

training_args = OnlineDPOConfig(
...,
@@ -391,7 +391,7 @@ training_args = GRPOConfig(
<hfoption id="OnlineDPO">

```python
- from trl import OnlineDPOConfig
+ from trl.experimental.online_dpo import OnlineDPOConfig

training_args = OnlineDPOConfig(
...,
3 changes: 1 addition & 2 deletions examples/scripts/online_dpo.py
@@ -58,15 +58,14 @@
from trl import (
LogCompletionsCallback,
ModelConfig,
- OnlineDPOConfig,
- OnlineDPOTrainer,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge
+ from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer


# Enable logging in a Hugging Face Space
3 changes: 1 addition & 2 deletions examples/scripts/online_dpo_vlm.py
@@ -92,14 +92,13 @@
from trl import (
LogCompletionsCallback,
ModelConfig,
- OnlineDPOConfig,
- OnlineDPOTrainer,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
+ from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
from trl.rewards import accuracy_reward, think_format_reward


@@ -19,9 +19,9 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import is_peft_available, is_vision_available

- from trl import OnlineDPOConfig, OnlineDPOTrainer
+ from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer

from .testing_utils import (
from ..testing_utils import (
RandomPairwiseJudge,
TrlTestCase,
require_llm_blender,
4 changes: 2 additions & 2 deletions trl/experimental/nash_md/nash_md_config.py
@@ -14,15 +14,15 @@

from dataclasses import dataclass, field

- from ...trainer.online_dpo_config import OnlineDPOConfig
+ from ..online_dpo import OnlineDPOConfig


@dataclass
class NashMDConfig(OnlineDPOConfig):
r"""
Configuration class for the [`experimental.nash_md.NashMDTrainer`].

- Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
+ Subclass of [`experimental.online_dpo.OnlineDPOConfig`] we can use all its arguments and add the following:

Parameters:
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
4 changes: 2 additions & 2 deletions trl/experimental/nash_md/nash_md_trainer.py
@@ -37,8 +37,8 @@
from ...models.modeling_base import GeometricMixtureWrapper
from ...models.utils import unwrap_model_for_generation
from ...trainer.judges import BasePairwiseJudge
- from ...trainer.online_dpo_trainer import OnlineDPOTrainer
from ...trainer.utils import SIMPLE_CHAT_TEMPLATE, empty_cache, get_reward, selective_log_softmax, truncate_right
+ from ..online_dpo import OnlineDPOTrainer
from .nash_md_config import NashMDConfig


@@ -50,7 +50,7 @@ class NashMDTrainer(OnlineDPOTrainer):
"""
Trainer for the Nash-MD method.

- It is implemented as a subclass of [`OnlineDPOTrainer`].
+ It is implemented as a subclass of [`experimental.online_dpo.OnlineDPOTrainer`].

Args:
model ([`~transformers.PreTrainedModel`]):
19 changes: 19 additions & 0 deletions trl/experimental/online_dpo/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .online_dpo_config import OnlineDPOConfig
from .online_dpo_trainer import OnlineDPOTrainer


__all__ = ["OnlineDPOConfig", "OnlineDPOTrainer"]
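
Reader's note: every import in this diff moves from `trl` to `trl.experimental.online_dpo`. For downstream code that has to run against TRL versions on either side of this change, one possible approach is a small import helper, sketched below. It assumes only the two module paths that appear in this diff. Whether the old top-level import keeps working after this PR is not shown here, and the helper name `load_online_dpo` is hypothetical.

```python
import importlib


def load_online_dpo(candidates=("trl.experimental.online_dpo", "trl")):
    """Return (OnlineDPOConfig, OnlineDPOTrainer) from the first module exposing both.

    Hypothetical compatibility helper; not part of TRL.
    """
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Module path does not exist in the installed version; try the next one.
            continue
        config_cls = getattr(module, "OnlineDPOConfig", None)
        trainer_cls = getattr(module, "OnlineDPOTrainer", None)
        if config_cls is not None and trainer_cls is not None:
            return config_cls, trainer_cls
    raise ImportError(
        "OnlineDPOConfig/OnlineDPOTrainer not found in any of: " + ", ".join(candidates)
    )
```

A caller would write `OnlineDPOConfig, OnlineDPOTrainer = load_online_dpo()` once and use the two names as before. Trying the experimental path first means the new location wins whenever both are importable.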