Refactor DPO #3906

Open
qgallouedec wants to merge 236 commits into main from refactor-dpo

Conversation

@qgallouedec qgallouedec commented Aug 15, 2025

What is this PR?

This pull request refactors the DPO Trainer.

Docs: https://moon-ci-docs.huggingface.co/docs/trl/pr_3906

Benchmark: #3906 (comment)

Closes #2563
Closes #3985
Closes #4071
Closes #2047

Important modifications

1. Remove the encoder-decoder support #

Remove encoder–decoder support to reduce code complexity and maintenance burden, focusing on decoder-only architectures which dominate current LLM usage and training workflows.

2. Remove RunningMoments from the pairwise BCO objective #

Section 4.2 of Binary Classifier Optimization for Large Language Model Alignment shows that alignment objectives must be invariant to adding constants to rewards, and the paper enforces this by formulating losses in terms of likelihood ratios and relative (baseline-subtracted) rewards. But this is automatically satisfied in the preference case, so we don't need any running moment like here

if "bco_pair" in self.loss_type:
self.running = RunningMoments(self.accelerator)

and here

trl/trl/trainer/dpo_trainer.py

Lines 1133 to 1137 in e5503ea

self.running.update(rewards)
delta = self.running.mean
losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
    -(self.beta * rejected_logratios - delta)
)

This is probably a mistake inherited from the initial implementation in #1524.
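For reference, if the baseline is dropped as proposed, the pairwise loss would reduce to something like the following (a sketch only, reusing the variable names from the snippet above):

losses = (
    -F.logsigmoid(self.beta * chosen_logratios)
    - F.logsigmoid(-self.beta * rejected_logratios)
)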

3. Rename "aot_pair" to "aot_unpaired" #

In the paper, we have:

  • "AOT paired" (for preference datasets), corresponding to loss_type="aot", and
  • "AOT unpaired" (for unpaired preference datasets), corresponding to loss_type="aot_pair"

From what I understand, for consistency with the now-removed loss_type="kto_pair", the author initially called the latter "aot_pair", even though it is the unpaired version (see #1701), which in my opinion is very misleading. I therefore propose to have:

  • "AOT paired": loss_type="aot" and
  • "AOT unpaired": loss_type="aot_unpaired"

We will follow the usual deprecation cycle, removing the old name after one minor version.
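The migration would then be a one-line change (sketch):

# before (deprecated)
training_args = DPOConfig(loss_type="aot_pair")

# after
training_args = DPOConfig(loss_type="aot_unpaired")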

4. Deprecate separate prompt/completion truncation #

DPOTrainer currently truncates prompts and completions separately using max_prompt_length and max_completion_length.

This is suboptimal: with a fixed total token budget, separate limits cannot adapt to varying prompt/completion lengths, causing unnecessary truncation.

prompts = [[1, 2, 3, 4, 5], [6, 7]]
completions = [[8, 9], [10, 11, 12, 13]]

Both samples fit within a 7-token budget.

Separate truncation fails

max_prompt_length = 5
max_completion_length = 2
# → [[1,  2,  3,  4,  5,  8,  9],
#    [6,  7, 10, 11]]   # truncated
max_prompt_length = 2
max_completion_length = 5
# → [[4,  5,  8,  9],   # truncated
#    [6,  7, 10, 11, 12, 13]]

No choice of (max_prompt_length, max_completion_length) preserves both samples.

Single-sequence truncation works

max_length = 7
# → [[1, 2, 3, 4, 5, 8, 9],
#    [6, 7, 10, 11, 12, 13]]
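For illustration, a single budget amounts to slicing the concatenated sequence (toy helper, not the trainer's actual API):

def truncate_joint(prompt_ids, completion_ids, max_length):
    # Concatenate prompt and completion, then truncate against one total budget,
    # keeping the start of the sequence (see point 5 below).
    return (prompt_ids + completion_ids)[:max_length]

prompts = [[1, 2, 3, 4, 5], [6, 7]]
completions = [[8, 9], [10, 11, 12, 13]]
print([truncate_joint(p, c, max_length=7) for p, c in zip(prompts, completions)])
# [[1, 2, 3, 4, 5, 8, 9], [6, 7, 10, 11, 12, 13]]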

Separate prompt/completion truncation wastes token budget.
Truncating the concatenated sequence with max_length is always strictly better.

Recommendation: deprecate max_prompt_length and max_completion_length

5. Switch the default truncation side to "keep_start" #

In DPO, preference labels are defined over the conditional distribution $p(y|x)$ induced by the full prompt.
Left truncation (keeping the end) alters this conditioning context by removing system instructions or intent-setting tokens, so the model is trained on preferences that no longer correspond to the same conditional distribution, potentially invalidating or even reversing the preference signal.
Right truncation (keeping the start) preserves the conditioning distribution and task semantics; while it may weaken the signal by shortening completions, it does not change what the preference is conditioned on. Moreover, because chosen and rejected responses typically have different lengths, left truncation can remove a different number of tokens from each completion, introducing additional asymmetry and noise in the preference comparison.

Therefore, I recommend setting truncation_side="keep_start" by default (instead of "keep_end").
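To make the two options concrete (toy example using plain slicing, not the trainer's implementation):

seq = [1, 2, 3, 4, 5, 6, 7, 8]  # tokenized prompt + completion
max_length = 5

seq[:max_length]   # "keep_start" → [1, 2, 3, 4, 5]: the conditioning context is preserved
seq[-max_length:]  # "keep_end"   → [4, 5, 6, 7, 8]: the start of the prompt is dropped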

6. Deprecate ref_model_init_kwargs #

To my knowledge, the reference model is almost always initialized to be identical to the initial trained model. Initializing it differently is therefore an advanced/uncommon use case, and in that case it is more logical and simpler to let the user initialize the reference model themselves and pass it to the trainer via the ref_model argument:

ref_model_id = ...  # str
ref_model_init_kwargs = {...}

# before
trainer = DPOTrainer(
    ...,
    ref_model=ref_model_id,
    args=DPOConfig(ref_model_init_kwargs=ref_model_init_kwargs),
)

# after
ref_model = AutoModelForCausalLM.from_pretrained(ref_model_id, **ref_model_init_kwargs)
trainer = DPOTrainer(
    ...,
    ref_model=ref_model,
)

7. Deprecate generate_during_eval #

Note

I'm not entirely sure about this one yet.

In all trainers, the trained model is available as model. It seems simpler to use a callback, such as LogCompletionsCallback, to generate completions during evaluation instead of doing it inside the trainer. Having both LogCompletionsCallback and generate_during_eval feels like a duplication to me.
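A sketch of what the callback-based flow would look like (the exact callback signature should be checked against the TRL docs):

from trl import DPOTrainer, LogCompletionsCallback

trainer = DPOTrainer(...)  # same arguments as before, without generate_during_eval

# Log sampled completions during evaluation via the callback instead
trainer.add_callback(LogCompletionsCallback(trainer))
trainer.train()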

8. Deprecate force_use_ref_model #

Before, providing both a peft_config and a ref_model would result in an error:

You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there
is no need to pass a reference model. Please pass `ref_model=None` if you want to train
PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init
if you want to use a different reference model.

There are two issues with this behavior:

  • The error was only triggered when a peft_config was provided, but not when the model argument was already a PEFT model, leading to inconsistent behavior. More generally, using a different ref_model is an uncommon usage pattern in DPO, regardless of whether PEFT is used.
  • If a user does provide a ref_model explicitly, it is reasonable to assume that this is intentional. Rejecting this combination at the API level and having a dedicated argument for this case is therefore unnecessarily restrictive.

As a result, it is cleaner to always honor an explicitly provided ref_model, while documenting that passing a ref_model is not necessary in the vast majority of use cases.

9. Deprecate use_logits_to_keep #

Previously, we had the option to enable use_logits_to_keep so that the lm_head was only applied to the last $N$ tokens, where $N$ is the largest completion length in the batch. This saved VRAM, but required a bit of extra complexity in the code:

trl/trl/trainer/dpo_trainer.py

Lines 1578 to 1615 in 1dc8bbc

if self.use_logits_to_keep:
    # Compute logits_to_keep based on loss_mask pattern:
    # [[0, 0, 0, x, x, x, x],
    #  [0, 0, 0, x, x, x, 0]]
    #         ^ start computing logits from here ([:, -(7-3+1):])
    first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
    logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
    model_kwargs["logits_to_keep"] = logits_to_keep

model_kwargs["output_hidden_states"] = True

if self.padding_free:
    # Flatten the input_ids, position_ids, and loss_mask
    # input_ids = [[a, b, c, 0],  ->  input_ids = [[a, b, c, d, e, f, g]]
    #              [d, e, f, g]]      position_ids = [[0, 1, 2, 0, 1, 2, 3]]
    input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
    loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
    position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
    model_kwargs["position_ids"] = position_ids
else:
    model_kwargs["attention_mask"] = attention_mask

outputs = model(input_ids, **model_kwargs)
logits = outputs.logits

# Offset the logits by one to align with the labels
labels = torch.roll(input_ids, shifts=-1, dims=1)
loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()

if self.use_logits_to_keep:
    # Align labels with logits
    # logits:    -,  -, [x2, x3, x4, x5, x6]
    #                    ^ --------- ^        after logits[:, :-1, :]
    # labels:   [y0, y1, y2, y3, y4, y5, y6]
    #                ^ --------- ^            with logits_to_keep=4, [:, -4:]
    # loss_mask: [0,  0,  0,  1,  1,  1,  1]
    labels = labels[:, -logits_to_keep:]
    loss_mask = loss_mask[:, -logits_to_keep:]

We're working on something even more efficient: applying the lm_head to completion tokens only (see internal discussion). This would always be enabled.

10. Deprecate label_pad_token_id #

It's now standard everywhere to use -100. In my opinion, having a way to parametrize this value is not useful.
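For context, -100 is the default ignore_index of PyTorch's cross-entropy, so label positions set to -100 (e.g. prompt tokens) simply don't contribute to the loss:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 8)                # (batch, seq_len, vocab_size)
labels = torch.tensor([[-100, -100, 3, 5]])  # prompt tokens masked with -100

# ignore_index defaults to -100, so the first two positions are skipped
loss = F.cross_entropy(logits.view(-1, 8), labels.view(-1))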

11. Deprecate FDivergenceType #

Before, the f_divergence_type could be provided either as an enum (FDivergenceType) or as a string. The enum adds unnecessary complexity and doesn't bring any real benefit. To stay consistent with loss_type, we should standardize on plain strings and deprecate the enum.

  from trl.trainer.dpo_config import DPOConfig, FDivergenceType

  config = DPOConfig(
-     f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE,
+     f_divergence_type="alpha_divergence",
  )

In the same spirit, I also removed FDivergenceConstants. I didn't add a deprecation path because this class had no user-facing value: it was only used internally, and the implementation was unnecessarily complex. The new approach is both simpler and more readable, and matches what we now use everywhere else.

12. Deprecate DPOConfig.tools #

To better align with other trainers (SFT, GRPO, RLOO, Reward), we should remove the tools argument from DPOConfig, and instead provide tools per example in the dataset via a tools column consumed by the chat template.

Before:

training_args = DPOConfig(
    ...,
    tools=tools,
)
next(iter(dataset))  # {'chosen': 'Some response', 'rejected': 'Another response'}

After:

training_args = DPOConfig(
    ...,
)
next(iter(dataset))  # {'chosen': 'Some response', 'rejected': 'Another response', 'tools': tools}

13. Deprecate reference_free #

To my understanding, the reference_free argument is redundant with the CPO trainer. It introduces a lot of special cases in the code, making it harder to maintain, and I see no codebase that uses it. Consequently, I suggest deprecating it to simplify the codebase. Users who want reference-free DPO can use CPO instead, as sketched below.
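A sketch of the migration (assuming the remaining arguments stay the same):

from trl import CPOConfig, CPOTrainer, DPOConfig, DPOTrainer

# before (deprecated)
trainer = DPOTrainer(..., args=DPOConfig(reference_free=True))

# after: reference-free training via CPO
trainer = CPOTrainer(..., args=CPOConfig())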

14. Deprecate base_model_attribute_name #

In Liger, we need to retrieve the underlying base model. Today, base_model_attribute_name is only used as a fallback when get_decoder is unavailable or returns None, so the argument is redundant in practice and can be deprecated.

15. Deprecate model_adapter_name and ref_adapter_name #

These arguments were originally meant to select which PEFT adapter to use for the training model and the reference model—mainly for setups where a single model might contain multiple adapters.

In practice, that complexity isn't needed when resuming from a pretrained adapter:

  • When you resume training, you only need to load the adapter you're going to train.
  • In the very rare case where you truly need two adapters (one for the train model and one for the reference model), you still just load those explicitly—and you control their names at load time. So passing adapter-name arguments through the trainer doesn't add value.

Instead, the recommended flow is:

  • Pass the PEFT-wrapped model directly as model (keep the default name "default")
  • The trainer will internally create the reference model by copying the adapter (and name it "ref")
  • If you really need a different adapter for the reference model, load it yourself and pass it as ref_model

To keep behavior consistent, training assumes a single adapter named "default". That means custom adapter names are no longer supported, and model_adapter_name / ref_adapter_name become unnecessary. Deprecating them removes redundant configuration and reduces confusion about how the reference model is produced.
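A sketch of that flow (model name and LoRA settings are placeholders):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from trl import DPOConfig, DPOTrainer

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder model
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))   # adapter name stays "default"

# The trainer copies the "default" adapter internally to create the reference model,
# so no ref_model or adapter-name arguments are needed.
trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="out"),
    train_dataset=...,  # your preference dataset
)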

16. Change the default value of f_alpha_divergence_coef from 1.0 to 0.5 #

We propose changing the default value of f_alpha_divergence_coef from 1.0 to 0.5. In the paper, the authors specify that α should lie in (0, 1), so α = 1 is excluded from the theoretically supported setting. Moreover, the α → 1 limit corresponds to the forward KL boundary case, which is already explicitly available via f_divergence_type="forward_kl" when that behavior is desired. In contrast, α = 0.5 sits well inside the valid interval and provides a more balanced trade-off between mode-seeking and mass-covering behavior, making it a safer and more generally robust default.
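For users who want to keep the current behavior, pinning the old value is a one-line change (sketch):

training_args = DPOConfig(
    f_divergence_type="alpha_divergence",
    f_alpha_divergence_coef=1.0,  # old default; the proposed new default is 0.5
)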
