From 2d2306a55c1a4f6909e3086aa1d7265c691d9944 Mon Sep 17 00:00:00 2001 From: Behrooz Date: Wed, 5 Nov 2025 16:59:07 -0800 Subject: [PATCH 1/8] Move ORPOTrainer and ORPOConfig to experimental - Move ORPOTrainer and ORPOConfig to trl.experimental.orpo - Add deprecation warnings in trl.trainer with removal planned for TRL 0.29.0 - Update imports in tests, examples, and documentation - Maintain backward compatibility through deprecation stubs Fixes #4465 --- docs/source/_toctree.yml | 8 +- docs/source/orpo_trainer.md | 2 +- examples/scripts/orpo.py | 3 +- tests/test_orpo_trainer.py | 2 +- trl/experimental/orpo/__init__.py | 19 + trl/experimental/orpo/orpo_config.py | 169 ++++ trl/experimental/orpo/orpo_trainer.py | 1050 +++++++++++++++++++++++++ trl/trainer/__init__.py | 4 - trl/trainer/orpo_config.py | 162 +--- trl/trainer/orpo_trainer.py | 1009 +----------------------- 10 files changed, 1290 insertions(+), 1138 deletions(-) create mode 100644 trl/experimental/orpo/__init__.py create mode 100644 trl/experimental/orpo/orpo_config.py create mode 100644 trl/experimental/orpo/orpo_trainer.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 74f79544b33..a7555a1e7ed 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -70,8 +70,6 @@ title: KTO - local: nash_md_trainer title: Nash-MD - - local: orpo_trainer - title: ORPO - local: ppo_trainer title: PPO - local: prm_trainer @@ -117,8 +115,10 @@ title: GRPO With Replay Buffer - local: gspo_token title: GSPO-token - - local: papo_trainer - title: PAPO - local: openenv title: OpenEnv Integration + - local: orpo_trainer + title: ORPO + - local: papo_trainer + title: PAPO title: Experimental \ No newline at end of file diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index 3092de2d2ae..d3428313bb8 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -34,7 +34,7 @@ Below is the script to train the model: ```python # train_orpo.py from datasets import load_dataset -from trl import ORPOConfig, ORPOTrainer +from trl.experimental.orpo import ORPOConfig, ORPOTrainer from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index e256a4277ad..36eb32a565b 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -63,7 +63,8 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config +from trl import ModelConfig, ScriptArguments, get_peft_config +from trl.experimental.orpo import ORPOConfig, ORPOTrainer # Enable logging in a Hugging Face Space diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 70f087ac948..554938293ca 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -17,7 +17,7 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from trl import ORPOConfig, ORPOTrainer +from trl.experimental.orpo import ORPOConfig, ORPOTrainer from .testing_utils import TrlTestCase, require_peft diff --git a/trl/experimental/orpo/__init__.py b/trl/experimental/orpo/__init__.py new file mode 100644 index 00000000000..17960ce5b18 --- /dev/null +++ b/trl/experimental/orpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .orpo_config import ORPOConfig +from .orpo_trainer import ORPOTrainer + + +__all__ = ["ORPOConfig", "ORPOTrainer"] diff --git a/trl/experimental/orpo/orpo_config.py b/trl/experimental/orpo/orpo_config.py new file mode 100644 index 00000000000..523beeab934 --- /dev/null +++ b/trl/experimental/orpo/orpo_config.py @@ -0,0 +1,169 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class ORPOConfig(TrainingArguments): + r""" + Configuration class for the [`ORPOTrainer`]. + + This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int`, *optional*): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the relative ratio loss weight in the ORPO loss. In the + [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the + [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, *optional*): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from the model to W&B or Comet during evaluation. + is_encoder_decoder (`bool`, *optional*): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_prompt_length: int | None = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + max_completion_length: int | None = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " + "denoted by λ." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + label_pad_token_id: int = field( + default=-100, + metadata={ + "help": "Label pad token id. This argument is required if you want to use the default data collator." + }, + ) + padding_value: int | None = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, + ) + is_encoder_decoder: bool | None = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " + "argument, you need to specify if the model returned by the callable is an encoder-decoder model." + }, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py new file mode 100644 index 00000000000..8ac3b0b6ed2 --- /dev/null +++ b/trl/experimental/orpo/orpo_trainer.py @@ -0,0 +1,1050 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +import warnings +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState, logging +from datasets import Dataset +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + is_comet_available, + is_torch_xla_available, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available, is_torch_fx_proxy + +from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt +from ..base_trainer import BaseTrainer +from .orpo_config import ORPOConfig +from ..utils import ( + DPODataCollatorWithPadding, + add_bos_token_if_needed, + add_eos_token_if_needed, + disable_dropout_in_model, + log_table_to_comet_experiment, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +logger = logging.get_logger(__name__) + + +class ORPOTrainer(BaseTrainer): + r""" + Initialize ORPOTrainer. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. + args ([`ORPOConfig`]): + The ORPO config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the + sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + """ + + _tag_names = ["trl", "orpo"] + _name = "ORPO" + _paper = { + "title": "ORPO: Monolithic Preference Optimization without Reference Model", + "id": "2403.07691", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{hong2024orpo, + title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, + author = {Jiwoo Hong and Noah Lee and James Thorne}, + year = 2024, + eprint = {arXiv:2403.07691} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + args: ORPOConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + dtype = model_init_kwargs.get("dtype") + if dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." + ) + model_init_kwargs["dtype"] = dtype + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if processing_class is None: + raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") + if args.max_length is None: + logger.warning( + "`max_length` is not set in the ORPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + ) + max_length = 512 + else: + max_length = args.max_length + if args.max_prompt_length is None: + logger.warning( + "`max_prompt_length` is not set in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + ) + max_prompt_length = 128 + else: + max_prompt_length = args.max_prompt_length + + if args.max_completion_length is None and self.is_encoder_decoder: + logger.warning( + "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + ) + self.max_completion_length = 128 + else: + self.max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + logger.warning( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.processing_class = processing_class + + self.beta = args.beta + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed, and apply the chat template if needed + train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + + b)[len(enc(a)):]`. Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict: + """Tokenize a single row from a ORPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + + chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, + we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length + of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.processing_class(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + a != b + for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True) + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt. Avoid adding if it's already there + prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( + self.processing_class.bos_token_id, + prompt_len_input_ids, + prompt_tokens, + chosen_prompt_len_input_ids, + chosen_tokens, + rejected_prompt_len_input_ids, + rejected_tokens, + ) + + # add EOS token to end of answer. Avoid adding if it's already there + chosen_tokens, rejected_tokens = add_eos_token_if_needed( + self.processing_class.eos_token_id, chosen_tokens, rejected_tokens + ) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.processing_class( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.processing_class( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.processing_class( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + if is_torch_xla_available(): + # Pad the sequences to global max_length to avoid TorchXLA recompilation + for k in batch: + if "labels" in k or self.is_encoder_decoder: + pad_value = self.label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = self.padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) + return batch + + @staticmethod + def concatenated_inputs( + batch: dict[str, list | torch.LongTensor], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: torch.device | None = None, + ) -> dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: + A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors + of shape (batch_size, sequence_length). + is_encoder_decoder: + Whether the model is an encoder-decoder model. + label_pad_token_id: + The label pad token id. + padding_value: + The padding value to use for the concatenated inputs_ids. + device: + The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def odds_ratio_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: + Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: + Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO + loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for + the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the + rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes. + """ + + # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) + ) + ratio = F.logsigmoid(log_odds) + losses = self.beta * ratio + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: + Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are + ignored. Shape: (batch_size, sequence_length) + average_log_prob: + If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the + log probabilities of the (non-masked) tokens. + label_pad_token_id: The label pad token id. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the + given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == label_pad_token_id, 0, labels) + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + # orpo chosen nll loss is computed over the full prompt and response + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1, :] + rejected_logits = all_logits[len_chosen:, :-1, :] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, list | torch.LongTensor], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() + metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( + chosen_rewards - rejected_rewards + ).mean() + metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() + metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() + metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( + policy_rejected_logits.detach().mean() + ).mean() + metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( + policy_chosen_logits.detach().mean() + ).mean() + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() + metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() + metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() + if is_torch_xla_available(): + xm.mark_step() # needed because .item() calls + for k, v in metrics.items(): + metrics[k] = v.item() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ): + if not self.use_dpo_data_collator: + logger.warning( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.generate_from_model(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy"], + data=[ + [prompt, pol[len(prompt) :]] + for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 98846bf7159..1f2e3ccec00 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -54,8 +54,6 @@ "nash_md_trainer": ["NashMDTrainer"], "online_dpo_config": ["OnlineDPOConfig"], "online_dpo_trainer": ["OnlineDPOTrainer"], - "orpo_config": ["ORPOConfig"], - "orpo_trainer": ["ORPOTrainer"], "ppo_config": ["PPOConfig"], "ppo_trainer": ["PPOTrainer"], "prm_config": ["PRMConfig"], @@ -114,8 +112,6 @@ from .nash_md_trainer import NashMDTrainer from .online_dpo_config import OnlineDPOConfig from .online_dpo_trainer import OnlineDPOTrainer - from .orpo_config import ORPOConfig - from .orpo_trainer import ORPOTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer from .prm_config import PRMConfig diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index 523beeab934..73d2345e88c 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -12,158 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field -from typing import Any +import warnings +from dataclasses import dataclass -from transformers import TrainingArguments +from ..experimental.orpo import ORPOConfig as ExperimentalORPOConfig @dataclass -class ORPOConfig(TrainingArguments): +class ORPOConfig(ExperimentalORPOConfig): r""" Configuration class for the [`ORPOTrainer`]. - This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. + - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. + This class has been moved to `trl.experimental.orpo.ORPOConfig` and will be removed in TRL 0.29.0. + Please update your imports: - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int`, *optional*): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the relative ratio loss weight in the ORPO loss. In the - [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the - [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, *optional*): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from the model to W&B or Comet during evaluation. - is_encoder_decoder (`bool`, *optional*): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - """ - - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + ```python + from trl.experimental.orpo import ORPOConfig + ``` - # Parameters whose default values are overridden from TrainingArguments - learning_rate: float = field( - default=1e-6, - metadata={"help": "The initial learning rate for AdamW."}, - ) - logging_steps: float = field( - default=10, - metadata={ - "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " - "will be interpreted as ratio of total training steps." - }, - ) - gradient_checkpointing: bool = field( - default=True, - metadata={ - "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." - }, - ) - bf16: bool | None = field( - default=None, - metadata={ - "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " - "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " - "`fp16` is not set." - }, - ) + For more details, see: https://github.com/huggingface/trl/issues/4223 - max_length: int | None = field( - default=1024, - metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, - ) - max_prompt_length: int | None = field( - default=512, - metadata={ - "help": "Maximum length of the prompt. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, - ) - max_completion_length: int | None = field( - default=None, - metadata={ - "help": "Maximum length of the completion. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, - ) - beta: float = field( - default=0.1, - metadata={ - "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " - "denoted by λ." - }, - ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model."}, - ) - label_pad_token_id: int = field( - default=-100, - metadata={ - "help": "Label pad token id. This argument is required if you want to use the default data collator." - }, - ) - padding_value: int | None = field( - default=None, - metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, - ) - truncation_mode: str = field( - default="keep_end", - metadata={ - "help": "Truncation mode to use when the prompt is too long.", - "choices": ["keep_end", "keep_start"], - }, - ) - generate_during_eval: bool = field( - default=False, - metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, - ) - is_encoder_decoder: bool | None = field( - default=None, - metadata={ - "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " - "argument, you need to specify if the model returned by the callable is an encoder-decoder model." - }, - ) - model_init_kwargs: dict[str, Any] | None = field( - default=None, - metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " - "from a string." - }, - ) - dataset_num_proc: int | None = field( - default=None, - metadata={"help": "Number of processes to use for processing the dataset."}, - ) + + """ def __post_init__(self): - self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 - + warnings.warn( + "ORPOConfig has been moved to trl.experimental.orpo.ORPOConfig and will be removed from " + "trl.trainer in TRL 0.29.0. Please update your imports to: " + "`from trl.experimental.orpo import ORPOConfig`. " + "For more details, see: https://github.com/huggingface/trl/issues/4223", + FutureWarning, + stacklevel=2, + ) super().__post_init__() diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index fb905243800..1c60e06c5f5 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -12,124 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -import os -import random -import textwrap import warnings -from collections import defaultdict from collections.abc import Callable -from contextlib import nullcontext -from pathlib import Path -from typing import Any, Literal +from typing import Any -import numpy as np -import pandas as pd -import torch import torch.nn as nn -import torch.nn.functional as F -from accelerate import PartialState, logging from datasets import Dataset -from torch import autocast -from torch.utils.data import DataLoader from transformers import ( - AutoModelForCausalLM, BaseImageProcessor, DataCollator, FeatureExtractionMixin, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, - is_comet_available, - is_torch_xla_available, - is_wandb_available, ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_peft_available, is_torch_fx_proxy -from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt -from .base_trainer import BaseTrainer +from ..experimental.orpo import ORPOTrainer as ExperimentalORPOTrainer from .orpo_config import ORPOConfig -from .utils import ( - DPODataCollatorWithPadding, - add_bos_token_if_needed, - add_eos_token_if_needed, - disable_dropout_in_model, - log_table_to_comet_experiment, - pad_to_length, - peft_module_casting_to_bf16, - selective_log_softmax, -) - - -if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training -if is_wandb_available(): - import wandb +class ORPOTrainer(ExperimentalORPOTrainer): + """ + Initialize ORPOTrainer. -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm + + This class has been moved to `trl.experimental.orpo.ORPOTrainer` and will be removed in TRL 0.29.0. + Please update your imports: -logger = logging.get_logger(__name__) + ```python + from trl.experimental.orpo import ORPOTrainer + ``` + For more details, see: https://github.com/huggingface/trl/issues/4223 -class ORPOTrainer(BaseTrainer): - r""" - Initialize ORPOTrainer. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. - args ([`ORPOConfig`]): - The ORPO config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. + """ - _tag_names = ["trl", "orpo"] - _name = "ORPO" - _paper = { - "title": "ORPO: Monolithic Preference Optimization without Reference Model", - "id": "2403.07691", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{hong2024orpo, - title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, - author = {Jiwoo Hong and Noah Lee and James Thorne}, - year = 2024, - eprint = {arXiv:2403.07691} - }"""), - } - def __init__( self, model: PreTrainedModel | nn.Module | str | None = None, @@ -149,207 +70,14 @@ def __init__( peft_config: dict | None = None, compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - dtype = model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - model_init_kwargs["dtype"] = dtype - - if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = get_peft_model(model, peft_config) - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve." - ) - - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - - if self.is_encoder_decoder: - self.decoder_start_token_id = model.config.decoder_start_token_id - self.pad_token_id = model.config.pad_token_id - - if processing_class is None: - raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") - if args.max_length is None: - logger.warning( - "`max_length` is not set in the ORPOConfig's init" - " it will default to `512` by default, but you should do it yourself in the future.", - ) - max_length = 512 - else: - max_length = args.max_length - if args.max_prompt_length is None: - logger.warning( - "`max_prompt_length` is not set in the ORPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - max_prompt_length = 128 - else: - max_prompt_length = args.max_prompt_length - - if args.max_completion_length is None and self.is_encoder_decoder: - logger.warning( - "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - self.max_completion_length = 128 - else: - self.max_completion_length = args.max_completion_length - - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - - self.max_length = max_length - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.processing_class = processing_class - - self.beta = args.beta - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - - # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the - # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and - # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().main_process_first(): - # Extract the prompt if needed, and apply the chat template if needed - train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - train_dataset = train_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc - ) - train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - if eval_dataset is not None: - eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - ) - eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - + warnings.warn( + "ORPOTrainer has been moved to trl.experimental.orpo.ORPOTrainer and will be removed from " + "trl.trainer in TRL 0.29.0. Please update your imports to: " + "`from trl.experimental.orpo import ORPOTrainer`. " + "For more details, see: https://github.com/huggingface/trl/issues/4223", + FutureWarning, + stacklevel=2, + ) super().__init__( model=model, args=args, @@ -358,700 +86,9 @@ def make_inputs_require_grad(module, input, output): eval_dataset=eval_dataset, processing_class=processing_class, model_init=model_init, - compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, + peft_config=peft_config, + compute_metrics=compute_metrics, ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - def build_tokenized_answer(self, prompt, answer): - """ - Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + - b)[len(enc(a)):]`. Reference: - https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - """ - - full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] - - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] - - # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` - full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) - - # Prepare input tokens for token by token comparison - full_input_ids = np.array(full_tokenized["input_ids"]) - - if len(full_input_ids) != len(full_concat_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = len(prompt_input_ids) - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. - if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] - answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - - return dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - input_ids=answer_input_ids, - attention_mask=answer_attention_mask, - ) - - def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict: - """Tokenize a single row from a ORPO specific dataset. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + - chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, - we truncate the chosen/rejected. - - We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length - of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. - """ - batch = {} - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - if not self.is_encoder_decoder: - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. https://github.com/LianjiaTech/BELLE/issues/337 - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") - prompt_tokens = self.processing_class(prompt, add_special_tokens=False) - prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} - - if not isinstance(chosen, str): - raise ValueError(f"chosen should be an str but got {type(chosen)}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) - - if not isinstance(rejected, str): - raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) - - # Last prompt token might get merged by tokenizer and - # it should not be included for generation if that happens - prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) - - chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) - rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) - prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) - - for k, v in prompt_tokens.items(): - prompt_tokens[k] = v[:prompt_len_input_ids] - - # Make sure prompts only have one different token at most an - # and length only differs by 1 at most - num_diff_tokens = sum( - a != b - for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True) - ) - num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) - if num_diff_tokens > 1 or num_diff_len > 1: - raise ValueError( - "Chosen and rejected prompt_input_ids might only differ on the " - "last token due to tokenizer merge ops." - ) - - # add BOS token to head of prompt. Avoid adding if it's already there - prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( - self.processing_class.bos_token_id, - prompt_len_input_ids, - prompt_tokens, - chosen_prompt_len_input_ids, - chosen_tokens, - rejected_prompt_len_input_ids, - rejected_tokens, - ) - - # add EOS token to end of answer. Avoid adding if it's already there - chosen_tokens, rejected_tokens = add_eos_token_if_needed( - self.processing_class.eos_token_id, chosen_tokens, rejected_tokens - ) - - longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - if self.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] - elif self.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - - # if that's still too long, truncate the response - for answer_tokens in [chosen_tokens, rejected_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] - - # Create labels - chosen_sequence_tokens = { - k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] - } - rejected_sequence_tokens = { - k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] - } - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(chosen_tokens["prompt_input_ids"]) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(rejected_tokens["prompt_input_ids"]) - - for k, toks in { - "chosen_": chosen_sequence_tokens, - "rejected_": rejected_sequence_tokens, - "": prompt_tokens, - }.items(): - for type_key, tokens in toks.items(): - if type_key == "token_type_ids": - continue - batch[f"{k}{type_key}"] = tokens - - else: - chosen_tokens = self.processing_class( - chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - rejected_tokens = self.processing_class( - rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - prompt_tokens = self.processing_class( - prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True - ) - - batch["chosen_labels"] = chosen_tokens["input_ids"] - batch["rejected_labels"] = rejected_tokens["input_ids"] - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] - - if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): - batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["rejected_labels"]) - ) - batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["chosen_labels"]) - ) - - if is_torch_xla_available(): - # Pad the sequences to global max_length to avoid TorchXLA recompilation - for k in batch: - if "labels" in k or self.is_encoder_decoder: - pad_value = self.label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = self.padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) - return batch - - @staticmethod - def concatenated_inputs( - batch: dict[str, list | torch.LongTensor], - is_encoder_decoder: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - device: torch.device | None = None, - ) -> dict[str, torch.LongTensor]: - """Concatenate the chosen and rejected inputs into a single tensor. - - Args: - batch: - A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors - of shape (batch_size, sequence_length). - is_encoder_decoder: - Whether the model is an encoder-decoder model. - label_pad_token_id: - The label pad token id. - padding_value: - The padding value to use for the concatenated inputs_ids. - device: - The device for the concatenated inputs. - - Returns: - A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. - """ - concatenated_batch = {} - - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) - else: - max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("rejected", "concatenated") - concatenated_batch[concatenated_key] = torch.cat( - ( - concatenated_batch[concatenated_key], - pad_to_length(batch[k], max_length, pad_value=pad_value), - ), - dim=0, - ).to(device=device) - - if is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1).to(device=device) - ) - - return concatenated_batch - - def odds_ratio_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: - Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) - policy_rejected_logps: - Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO - loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for - the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the - rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes. - """ - - # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) - log_odds = (policy_chosen_logps - policy_rejected_logps) - ( - torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) - ) - ratio = F.logsigmoid(log_odds) - losses = self.beta * ratio - - chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() - rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() - - return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: - Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are - ignored. Shape: (batch_size, sequence_length) - average_log_prob: - If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the - log probabilities of the (non-masked) tokens. - label_pad_token_id: The label pad token id. - is_encoder_decoder: Whether the model is an encoder-decoder model. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the - given logits. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels = torch.where(labels == label_pad_token_id, 0, labels) - - per_token_logps = selective_log_softmax(logits, labels) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def concatenated_forward( - self, model: nn.Module, batch: dict[str, list | torch.LongTensor] - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - """ - concatenated_batch = self.concatenated_inputs( - batch, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - padding_value=self.padding_value, - device=self.accelerator.device, - ) - len_chosen = batch["chosen_labels"].shape[0] - - model_kwargs = ( - { - "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), - } - if self.is_encoder_decoder - else {} - ) - - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) - all_logits = outputs.logits - - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - if self.is_encoder_decoder: - labels = concatenated_batch["concatenated_labels"].clone() - else: - labels = concatenated_batch["concatenated_input_ids"].clone() - attention_mask = concatenated_batch["concatenated_attention_mask"] - labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) - # orpo chosen nll loss is computed over the full prompt and response - chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) - - all_logps = self.get_batch_logps( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=True, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - if not self.is_encoder_decoder: - chosen_logits = all_logits[:len_chosen, :-1, :] - rejected_logits = all_logits[len_chosen:, :-1, :] - else: - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) - - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) - - def get_batch_loss_metrics( - self, - model, - batch: dict[str, list | torch.LongTensor], - train_eval: Literal["train", "eval"] = "train", - ): - """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - - forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] - - losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( - policy_chosen_logps, policy_rejected_logps - ) - # full ORPO loss - loss = policy_nll_loss - losses.mean() - - reward_accuracies = (chosen_rewards > rejected_rewards).float() - - prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() - metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() - metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() - metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( - chosen_rewards - rejected_rewards - ).mean() - metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() - metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() - metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( - policy_rejected_logits.detach().mean() - ).mean() - metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( - policy_chosen_logits.detach().mean() - ).mean() - metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() - metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() - metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() - if is_torch_xla_available(): - xm.mark_step() # needed because .item() calls - for k, v in metrics.items(): - metrics[k] = v.item() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss - - return loss, metrics - - def compute_loss( - self, - model: PreTrainedModel | nn.Module, - inputs: dict[str, torch.Tensor | Any], - return_outputs=False, - num_items_in_batch=None, - ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") - - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - - # force log the metrics - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return (loss, metrics) - return loss - - def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - return policy_output_decoded - - def prediction_step( - self, - model: PreTrainedModel | nn.Module, - inputs: dict[str, torch.Tensor | Any], - prediction_loss_only: bool, - ignore_keys: list[str] | None = None, - ): - if not self.use_dpo_data_collator: - logger.warning( - "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " - "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" - ) - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") - - # force log the metrics - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return (loss.detach(), None, None) - - # logits for the chosen and rejected samples from model - logits_dict = { - "eval_logits/chosen": metrics["eval_logits/chosen"], - "eval_logits/rejected": metrics["eval_logits/rejected"], - } - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: bool | None = None, - ignore_keys: list[str] | None = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded = self.generate_from_model(self.model, random_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy"], - data=[ - [prompt, pol[len(prompt) :]] - for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True) - ], - ) - if "wandb" in self.args.report_to: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: float | None = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - def _shift_right(self, input_ids): - if self.decoder_start_token_id is None: - raise ValueError( - "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) - shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = self.decoder_start_token_id - - if self.pad_token_id is None: - raise ValueError("model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) - - return shifted_input_ids - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) From 18040a40375720aec6f7d57aaa7eae7e8373482b Mon Sep 17 00:00:00 2001 From: Behrooz Date: Wed, 5 Nov 2025 18:36:39 -0800 Subject: [PATCH 2/8] Address reviewer feedback on ORPO experimental migration - Restore ORPO imports in trl/trainer/__init__.py for backward compatibility - Fix deprecation stub naming from ExperimentalORPOTrainer to _ORPOTrainer - Add torch import to deprecation stub for type hints - Fix relative import paths in trl/experimental/orpo/orpo_trainer.py - Update autodoc references to experimental.orpo.ORPOTrainer - Update all documentation references to use experimental namespace - Move ORPO test from test_trainers_args.py to experimental/test_trainers_args.py --- docs/source/community_tutorials.md | 2 +- docs/source/dataset_formats.md | 2 +- docs/source/example_overview.md | 2 +- docs/source/index.md | 2 +- docs/source/orpo_trainer.md | 8 ++--- tests/{ => experimental}/test_orpo_trainer.py | 0 tests/experimental/test_trainers_args.py | 28 ++++++++++++++++++ tests/test_trainers_args.py | 29 ------------------- trl/experimental/orpo/orpo_trainer.py | 4 +-- trl/trainer/__init__.py | 4 +++ trl/trainer/orpo_trainer.py | 5 ++-- 11 files changed, 45 insertions(+), 41 deletions(-) rename tests/{ => experimental}/test_orpo_trainer.py (100%) diff --git a/docs/source/community_tutorials.md b/docs/source/community_tutorials.md index 333aa973b5a..a412ee7d917 100644 --- a/docs/source/community_tutorials.md +++ b/docs/source/community_tutorials.md @@ -15,7 +15,7 @@ Community tutorials are made by active members of the Hugging Face community who | Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) | | Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) | | Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) | -| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | +| Preference Optimization | [`experimental.orpo.ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | | Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) | ### Videos diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 958dfb3af52..7d98b8abe7b 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -395,7 +395,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | | [`NashMDTrainer`] | [Prompt-only](#prompt-only) | | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | -| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`PPOTrainer`] | Tokenized language modeling | | [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 0f12f5ba1a0..d9cba0b9114 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -54,7 +54,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | -| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`experimental.orpo.ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | | [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). | diff --git a/docs/source/index.md b/docs/source/index.md index 9d6584cc2b8..e0268d51868 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -41,8 +41,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL - [`SFTTrainer`] - [`DPOTrainer`] -- [`ORPOTrainer`] - [`experimental.bco.BCOTrainer`] 🧪 +- [`experimental.orpo.ORPOTrainer`] 🧪 - [`CPOTrainer`] - [`KTOTrainer`] diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index d3428313bb8..555f0858316 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -79,9 +79,9 @@ Here are some other factors to consider when choosing a programming language for ## Expected dataset type -ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +ORPO requires a [preference dataset](dataset_formats#preference). The [`experimental.orpo.ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. -Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. +Although the [`experimental.orpo.ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. ## Example script @@ -121,11 +121,11 @@ While training and evaluating, we record the following reward metrics: ## ORPOTrainer -[[autodoc]] ORPOTrainer +[[autodoc]] experimental.orpo.ORPOTrainer - train - save_model - push_to_hub ## ORPOConfig -[[autodoc]] ORPOConfig +[[autodoc]] experimental.orpo.ORPOConfig diff --git a/tests/test_orpo_trainer.py b/tests/experimental/test_orpo_trainer.py similarity index 100% rename from tests/test_orpo_trainer.py rename to tests/experimental/test_orpo_trainer.py diff --git a/tests/experimental/test_trainers_args.py b/tests/experimental/test_trainers_args.py index bd86bb61b5d..6b3e1bbb0f1 100644 --- a/tests/experimental/test_trainers_args.py +++ b/tests/experimental/test_trainers_args.py @@ -16,6 +16,7 @@ from transformers import AutoTokenizer from trl.experimental.bco import BCOConfig, BCOTrainer +from trl.experimental.orpo import ORPOConfig, ORPOTrainer from ..testing_utils import TrlTestCase, require_sklearn @@ -68,3 +69,30 @@ def test_bco(self): assert trainer.args.prompt_sample_size == 512 assert trainer.args.min_density_ratio == 0.2 assert trainer.args.max_density_ratio == 20.0 + + def test_orpo(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + training_args = ORPOConfig( + self.tmp_dir, + max_length=256, + max_prompt_length=64, + max_completion_length=64, + beta=0.5, + disable_dropout=False, + label_pad_token_id=-99, + padding_value=-99, + truncation_mode="keep_start", + # generate_during_eval=True, # ignore this one, it requires wandb + is_encoder_decoder=True, + model_init_kwargs={"trust_remote_code": True}, + dataset_num_proc=4, + ) + trainer = ORPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert not trainer.args.disable_dropout + assert trainer.args.label_pad_token_id == -99 diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 014ec6ac5da..1a6c8171c3f 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -28,8 +28,6 @@ NashMDTrainer, OnlineDPOConfig, OnlineDPOTrainer, - ORPOConfig, - ORPOTrainer, RewardConfig, RewardTrainer, SFTConfig, @@ -248,33 +246,6 @@ def test_online_dpo(self, beta_list): assert trainer.args.beta == (0.6 if not beta_list else [0.6, 0.7]) assert trainer.args.loss_type == "hinge" - def test_orpo(self): - model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" - tokenizer = AutoTokenizer.from_pretrained(model_id) - dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - training_args = ORPOConfig( - self.tmp_dir, - max_length=256, - max_prompt_length=64, - max_completion_length=64, - beta=0.5, - disable_dropout=False, - label_pad_token_id=-99, - padding_value=-99, - truncation_mode="keep_start", - # generate_during_eval=True, # ignore this one, it requires wandb - is_encoder_decoder=True, - model_init_kwargs={"trust_remote_code": True}, - dataset_num_proc=4, - ) - trainer = ORPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) - assert trainer.args.max_length == 256 - assert trainer.args.max_prompt_length == 64 - assert trainer.args.max_completion_length == 64 - assert trainer.args.beta == 0.5 - assert not trainer.args.disable_dropout - assert trainer.args.label_pad_token_id == -99 - def test_reward(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" tokenizer = AutoTokenizer.from_pretrained(model_id) diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py index 8ac3b0b6ed2..b490cdf6951 100644 --- a/trl/experimental/orpo/orpo_trainer.py +++ b/trl/experimental/orpo/orpo_trainer.py @@ -49,9 +49,9 @@ from transformers.utils import is_peft_available, is_torch_fx_proxy from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt -from ..base_trainer import BaseTrainer +from ...trainer.base_trainer import BaseTrainer from .orpo_config import ORPOConfig -from ..utils import ( +from ...trainer.utils import ( DPODataCollatorWithPadding, add_bos_token_if_needed, add_eos_token_if_needed, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 1f2e3ccec00..98846bf7159 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -54,6 +54,8 @@ "nash_md_trainer": ["NashMDTrainer"], "online_dpo_config": ["OnlineDPOConfig"], "online_dpo_trainer": ["OnlineDPOTrainer"], + "orpo_config": ["ORPOConfig"], + "orpo_trainer": ["ORPOTrainer"], "ppo_config": ["PPOConfig"], "ppo_trainer": ["PPOTrainer"], "prm_config": ["PRMConfig"], @@ -112,6 +114,8 @@ from .nash_md_trainer import NashMDTrainer from .online_dpo_config import OnlineDPOConfig from .online_dpo_trainer import OnlineDPOTrainer + from .orpo_config import ORPOConfig + from .orpo_trainer import ORPOTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer from .prm_config import PRMConfig diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 1c60e06c5f5..a89535ca615 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -16,6 +16,7 @@ from collections.abc import Callable from typing import Any +import torch import torch.nn as nn from datasets import Dataset from transformers import ( @@ -29,11 +30,11 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput -from ..experimental.orpo import ORPOTrainer as ExperimentalORPOTrainer +from ..experimental.orpo import ORPOTrainer as _ORPOTrainer from .orpo_config import ORPOConfig -class ORPOTrainer(ExperimentalORPOTrainer): +class ORPOTrainer(_ORPOTrainer): """ Initialize ORPOTrainer. From 9d7c53c447645cb9f5edb6ebc066de624e3137c6 Mon Sep 17 00:00:00 2001 From: Behrooz Date: Wed, 5 Nov 2025 18:41:15 -0800 Subject: [PATCH 3/8] Fix ruff linting errors - remove unused imports - Remove unused 'import os' and 'import warnings' from trl/experimental/orpo/orpo_trainer.py - Remove unused 'from typing import Any' from trl/trainer/orpo_trainer.py --- trl/experimental/orpo/orpo_trainer.py | 2 -- trl/trainer/orpo_trainer.py | 1 - 2 files changed, 3 deletions(-) diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py index b490cdf6951..fd68515a900 100644 --- a/trl/experimental/orpo/orpo_trainer.py +++ b/trl/experimental/orpo/orpo_trainer.py @@ -13,10 +13,8 @@ # limitations under the License. import inspect -import os import random import textwrap -import warnings from collections import defaultdict from collections.abc import Callable from contextlib import nullcontext diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index a89535ca615..7f5feb34489 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -14,7 +14,6 @@ import warnings from collections.abc import Callable -from typing import Any import torch import torch.nn as nn From 92e218b933e0b1aa70b62c3ff3a74851242be2fb Mon Sep 17 00:00:00 2001 From: Behrooz Date: Wed, 5 Nov 2025 19:58:55 -0800 Subject: [PATCH 4/8] Fix import path for testing_utils in ORPO test file --- tests/experimental/test_orpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/experimental/test_orpo_trainer.py b/tests/experimental/test_orpo_trainer.py index 554938293ca..95a82234909 100644 --- a/tests/experimental/test_orpo_trainer.py +++ b/tests/experimental/test_orpo_trainer.py @@ -19,7 +19,7 @@ from trl.experimental.orpo import ORPOConfig, ORPOTrainer -from .testing_utils import TrlTestCase, require_peft +from ..testing_utils import TrlTestCase, require_peft class TestORPOTrainer(TrlTestCase): From c2db59638331ef5085867fb2406a5872060ad3fc Mon Sep 17 00:00:00 2001 From: Behrooz Date: Wed, 5 Nov 2025 20:02:23 -0800 Subject: [PATCH 5/8] Fix import ordering in ORPO trainer --- trl/experimental/orpo/orpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py index fd68515a900..29af144bb57 100644 --- a/trl/experimental/orpo/orpo_trainer.py +++ b/trl/experimental/orpo/orpo_trainer.py @@ -48,7 +48,6 @@ from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt from ...trainer.base_trainer import BaseTrainer -from .orpo_config import ORPOConfig from ...trainer.utils import ( DPODataCollatorWithPadding, add_bos_token_if_needed, @@ -59,6 +58,7 @@ peft_module_casting_to_bf16, selective_log_softmax, ) +from .orpo_config import ORPOConfig if is_peft_available(): From b6815e3f80a937d5001530e1214da74cbe40932e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 21 Nov 2025 05:43:23 +0000 Subject: [PATCH 6/8] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 52ed4df2c09bfa9f04242923ad72c067ea13be94 Author: Quentin Gallouédec Date: Thu Nov 20 21:41:23 2025 +0000 Fix style OpenEnv example commit a2639462fac330b0ac06c36dfa06bd840f305b61 Author: Sergio Paniego Blanco Date: Thu Nov 20 14:44:15 2025 +0100 Update OpenEnv guide with latest details (#4552) Co-authored-by: burtenshaw commit 1a9ff522308331ac32e3bd7076a09f3c1a922c1e Author: Kashif Rasul Date: Wed Nov 19 15:34:25 2025 +0100 [OpenEnv] browsergym example script (#4539) Co-authored-by: Sergio Paniego Blanco commit 6cbcd9413440ec4663a90c8b1cafd71b394f0711 Author: Sergio Paniego Blanco Date: Wed Nov 19 14:39:44 2025 +0100 Update OpenEnv example scripts (#4547) commit 85105890c185be95bf5d9fcdc030b18cecf2f302 Author: Sergio Paniego Blanco Date: Wed Nov 19 14:39:20 2025 +0100 Add OpenEnv Script examples to docs (#4533) commit e622196097109080b73584d598d4162e64fc6bea Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Mon Nov 17 03:12:30 2025 -0700 [Doc] Drop dummy reward and dataset for DeepMath-103K and accuracy reward (#4524) commit 1b1242cc6522feb4eb063feb20097a79b11b127a Author: Kashif Rasul Date: Fri Nov 14 20:51:41 2025 +0100 [OpenEnv] add vllm colocate mode to openenv scripts (#4510) Co-authored-by: Sergio Paniego Blanco Co-authored-by: Quentin Gallouédec commit f39d18a05d002df953f6cd6609415048548c5f85 Author: Fabio Milentiansen Sim Date: Fri Nov 14 23:39:02 2025 +0700 fix(GOLDTrainer): Resolve incorrect attribute access and VLLMClient.generate() output type (#4526) commit d45eaab3af6ab8c80a7c5b65df607a5152ed0f77 Author: Sergio Paniego Blanco Date: Fri Nov 14 12:12:09 2025 +0100 Add vLLM quantization option for colocate (#4496) Co-authored-by: Kashif Rasul commit a91d4b379a7e0af48bd879a7268f4337f3e22f36 Author: Sergio Paniego Blanco Date: Fri Nov 14 02:19:08 2025 +0100 Prevent upcasting norm layers in `prepare_model_for_kbit_training` (#4457) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> commit 121318e281c33deb5c6df8c399af0c5cdf15506c Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Thu Nov 13 17:13:16 2025 -0800 docs: Extend CLI basic usage examples to all supported CLIs (#4425) Co-authored-by: Sergio Paniego Blanco Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 79183203fab0faede45356d0242f45f40b55289e Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Thu Nov 13 13:20:52 2025 -0700 Remove test trainer args (#4517) commit 102dc4184c86a3c6d890b790a30dcd026e599071 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Thu Nov 13 12:36:43 2025 -0700 Rename `flash-attn` to `flash-attn2` (#4514) Co-authored-by: Sergio Paniego Blanco commit 5de62b07e380762b4fd0ed24e9449385c4d820ee Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Thu Nov 13 12:05:48 2025 -0700 Add step time metric to GRPO Trainer for performance tracking (#4516) Co-authored-by: lewtun commit f1e6377e4f301cafb0d8dc29b9afe8da930facfe Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Thu Nov 13 11:01:19 2025 -0800 Move PPOTrainer to trl.experimental.ppo (#4482) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 01f497e2e11350f81d2777eaa53f2282e675201e Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Thu Nov 13 10:14:58 2025 -0800 Move NashMDTrainer to experimental module (#4477) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit b6c838aa24c3b58930bd4faba9704c0223ee0590 Author: Quentin Gallouédec Date: Thu Nov 13 16:53:26 2025 +0000 `aws-general-8-plus` runner for Docker build commit ed5c7bb5b07845b18f40627d69fe133884f72f39 Author: YangKai0616 Date: Fri Nov 14 00:42:48 2025 +0800 [Bug Fix] OnlineDPOTrainer with vLLM Server Mode (#4500) commit ded9bc6164f3bdbe1df35adb77eb8be4594f94b3 Author: lewtun Date: Thu Nov 13 17:33:59 2025 +0100 Fix Docker images for Liger (#4522) commit fd04760f594e9262cbf9abaccfef2bad05569775 Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Thu Nov 13 11:31:10 2025 +0000 Paper Index: Change `num_completions` to `num_generations` (#4515) commit b7918c0f3bdba2e327bad7abf01aa1becfab3565 Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Wed Nov 12 20:35:44 2025 -0800 Move GKDTrainer to experimental module (#4474) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 07b5011b7864d4a27b71e30963a0a4ca610da3fd Author: Tamoghno Kandar <55907205+tamoghnokandar@users.noreply.github.com> Date: Wed Nov 12 20:07:33 2025 -0800 Replace flash attention2 with kernels-community/flash-attn2 (#4426) Co-authored-by: Quentin Gallouédec Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> commit 7a57fd41e41b33c7fd03a24abb45dd1cdebcbc49 Author: Yuxian Gu Date: Thu Nov 13 11:16:20 2025 +0800 MiniLLM: Fix arguments in config & add to documentation index (#4518) commit a145eaf81ac664d62c63d7d088f4d5c7d261f5b2 Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Wed Nov 12 16:35:46 2025 -0800 refactor: Move CPOTrainer to experimental module (#4470) commit d2dc717e03062d129e60865eca3e85ed8fa73538 Author: Taha Yassine <40228615+taha-yassine@users.noreply.github.com> Date: Thu Nov 13 00:56:47 2025 +0100 Replace `wandb_log_unique_prompts` with `log_unique_prompts` (#4508) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 799b39b86408bc8d356cd0fbd398741becdfd059 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Nov 12 16:21:05 2025 -0700 `device_map` and `dtype` to `"auto"` by default (#4509) Co-authored-by: Sergio Paniego Blanco commit a6a2beb937377df7078537a3454483a01000868b Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Nov 12 09:42:31 2025 -0700 Add temporary workaround for `lr_scheduler_kwargs` dtype issue in Transformers 4.57.0 (#4513) commit 346701ae6e5cf4b1797734732fc8040bbbec9e55 Author: lewtun Date: Wed Nov 12 17:42:18 2025 +0100 Replace accelerate logging with stdlib in CLI (#4512) commit 4db63af98b6437b18208a34455b6b10692086800 Author: Quentin Gallouédec Date: Wed Nov 12 02:19:51 2025 +0000 Fix GRPO unsqueeze advantages commit ecb2811535daf0aabcd3cb88d89909cf08fb89ad Author: Yuxian Gu Date: Wed Nov 12 10:17:22 2025 +0800 Add MiniLLM Trainer (#4504) Co-authored-by: Quentin Gallouédec commit 89e46883a1c6b2feb36eff99b937095b39de77da Author: Taha Yassine <40228615+taha-yassine@users.noreply.github.com> Date: Tue Nov 11 20:36:23 2025 +0100 Add support for images inside tables with Trackio completions logging (#4505) commit 2d3279c2c2eac73c1f84a3efd3bf414913146b90 Author: lewtun Date: Tue Nov 11 19:22:25 2025 +0100 Tweak description for vLLM sleep mode (#4506) Co-authored-by: Quentin Gallouédec commit 02a34777c36173ea53c8ce9979db50f13c1aaac8 Author: Luke Hinds Date: Mon Nov 10 16:41:51 2025 +0000 Fix link to OpenEnv docs (#4502) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> commit aaed6c1600ff4f3e0ccc6b3b8183c98d26390491 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Sat Nov 8 08:20:48 2025 -0700 Consistency regarding relative imports (#4498) commit 20760ba3ac7092432f67a5e3e9b2c624f168fdb0 Author: burtenshaw Date: Fri Nov 7 10:50:50 2025 +0100 [DOCS] update and fix openenv (#4490) Co-authored-by: Kashif Rasul Co-authored-by: Sergio Paniego Blanco commit 64cfca42297311e1b1c41a32d0867b16619d577c Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Thu Nov 6 22:47:04 2025 -0800 Move judges to experimental submodule (#4439) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 97ca1a2569367d4d1d30bedf8d29377b3d4d1ec4 Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Fri Nov 7 00:20:15 2025 +0000 Fix bugs in CISPO conditions (#4499) commit ffb3dd5d2e9e4d3b866527b861313be63981e024 Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Thu Nov 6 16:03:00 2025 -0800 docs: Add PEFT subsection to reducing memory usage guide (#4430) Co-authored-by: Sergio Paniego Blanco commit 43b6541aa46669a5ea79327c989730b282f18076 Author: SolarWindRider <31797478+SolarWindRider@users.noreply.github.com> Date: Fri Nov 7 06:55:34 2025 +0800 Support completion bootstrap for VLM in GRPO/RLOO (#4452) Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 642b721ee52c265835e2826c20316165b3f7a24f Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Thu Nov 6 22:33:00 2025 +0000 ScaleRL: Add CISPO Loss (#4495) commit 32e9c9fa6a4def84b128042092422534920903f1 Author: Ishita Bhattacharyya <139248026+ishitab02@users.noreply.github.com> Date: Fri Nov 7 03:37:43 2025 +0530 ⛴️ Add kernels to Docker images (#4445) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 1bcfc500eb548344de0567f2c9d277379b3db940 Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Thu Nov 6 13:40:12 2025 -0800 Move XPOTrainer to trl.experimental.xpo (#4485) Co-authored-by: Invidia19 <54266187+Invidia19@users.noreply.github.com> Co-authored-by: Quentin Gallouédec commit 37942bc19fb7d2bc8c43a839b32214ede53a0976 Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Thu Nov 6 21:32:03 2025 +0000 Buffer samples based on group level stds. (#4492) commit 66cd02a6f50ae6a9dcb252bd86fe434a14299651 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu Nov 6 20:58:25 2025 +0100 Add tiny model Qwen3VLForConditionalGeneration to CI (#4494) commit 32febb491b386881fe75137e9990fb0f1b5cae2c Author: Sergio Paniego Blanco Date: Thu Nov 6 18:21:56 2025 +0100 Add LFM2 to SFT notebook examples (#4455) --- .github/workflows/docker-build.yml | 6 +- README.md | 13 +- docker/trl-dev/Dockerfile | 5 +- docker/trl/Dockerfile | 6 +- docs/source/_toctree.yml | 30 +- docs/source/clis.md | 335 ++++- docs/source/cpo_trainer.md | 22 +- docs/source/dataset_formats.md | 14 +- docs/source/dpo_trainer.md | 2 +- docs/source/example_overview.md | 20 +- docs/source/gkd_trainer.md | 17 +- docs/source/gold_trainer.md | 4 +- docs/source/grpo_trainer.md | 38 +- docs/source/index.md | 13 +- docs/source/judges.md | 26 +- docs/source/kernels_hub.md | 8 +- docs/source/liger_kernel_integration.md | 2 +- docs/source/minillm.md | 67 + docs/source/nash_md_trainer.md | 19 +- docs/source/online_dpo_trainer.md | 7 +- docs/source/openenv.md | 355 ++++-- docs/source/paper_index.md | 60 +- docs/source/peft_integration.md | 3 +- docs/source/ppo_trainer.md | 10 +- docs/source/quickstart.md | 9 +- docs/source/reducing_memory_usage.md | 37 +- docs/source/rloo_trainer.md | 20 +- docs/source/vllm_integration.md | 104 +- docs/source/xpo_trainer.md | 22 +- examples/datasets/deepmath_103k.py | 98 ++ examples/notebooks/sft_trl_lora_qlora.ipynb | 5 +- examples/scripts/cpo.py | 3 +- examples/scripts/evals/judge_tldr.py | 2 +- examples/scripts/gkd.py | 3 +- examples/scripts/nash_md.py | 7 +- examples/scripts/online_dpo.py | 4 +- examples/scripts/openenv/browsergym.py | 572 +++++++++ examples/scripts/openenv/catch.py | 176 +-- examples/scripts/openenv/echo.py | 120 +- examples/scripts/openenv/wordle.py | 309 ++--- examples/scripts/ppo/ppo.py | 11 +- examples/scripts/ppo/ppo_tldr.py | 11 +- examples/scripts/xpo.py | 7 +- pyproject.toml | 5 + scripts/generate_tiny_models.py | 13 + tests/{ => experimental}/test_cpo_trainer.py | 4 +- tests/{ => experimental}/test_gkd_trainer.py | 4 +- .../test_grpo_with_replay_buffer_trainer.py | 4 +- tests/{ => experimental}/test_judges.py | 4 +- tests/experimental/test_minillm_trainer.py | 57 + .../test_nash_md_trainer.py | 4 +- tests/{ => experimental}/test_ppo_trainer.py | 6 +- tests/experimental/test_trainers_args.py | 98 -- tests/{ => experimental}/test_xpo_trainer.py | 5 +- tests/test_callbacks.py | 2 +- tests/test_grpo_trainer.py | 10 +- tests/test_online_dpo_trainer.py | 18 +- tests/test_sft_trainer.py | 25 +- tests/test_trainers_args.py | 315 ----- tests/testing_utils.py | 23 +- trl/cli.py | 16 +- trl/experimental/bco/bco_config.py | 10 + trl/experimental/bco/bco_trainer.py | 6 +- trl/experimental/cpo/__init__.py | 19 + trl/experimental/cpo/cpo_config.py | 228 ++++ trl/experimental/cpo/cpo_trainer.py | 1089 +++++++++++++++++ trl/experimental/gfpo/gfpo_trainer.py | 3 + trl/experimental/gkd/__init__.py | 19 + trl/experimental/gkd/gkd_config.py | 112 ++ trl/experimental/gkd/gkd_trainer.py | 440 +++++++ trl/experimental/gold/gold_config.py | 7 +- trl/experimental/gold/gold_trainer.py | 19 +- .../grpo_with_replay_buffer_config.py | 2 +- .../grpo_with_replay_buffer_trainer.py | 24 +- trl/experimental/gspo_token/grpo_trainer.py | 3 +- trl/experimental/judges/__init__.py | 36 + trl/experimental/judges/judges.py | 457 +++++++ trl/experimental/minillm/__init__.py | 19 + trl/experimental/minillm/minillm_config.py | 145 +++ trl/experimental/minillm/minillm_trainer.py | 396 ++++++ trl/experimental/nash_md/__init__.py | 19 + trl/experimental/nash_md/nash_md_config.py | 46 + trl/experimental/nash_md/nash_md_trainer.py | 489 ++++++++ trl/experimental/openenv/__init__.py | 18 + trl/experimental/openenv/utils.py | 137 +++ trl/experimental/ppo/__init__.py | 19 + trl/experimental/ppo/ppo_config.py | 135 ++ trl/experimental/ppo/ppo_trainer.py | 836 +++++++++++++ trl/experimental/xpo/__init__.py | 19 + trl/experimental/xpo/xpo_config.py | 44 + trl/experimental/xpo/xpo_trainer.py | 538 ++++++++ trl/extras/vllm_client.py | 3 + trl/mergekit_utils.py | 2 +- trl/models/utils.py | 10 +- trl/rewards/accuracy_rewards.py | 20 +- trl/scripts/grpo.py | 18 + trl/trainer/callbacks.py | 5 +- trl/trainer/cpo_config.py | 207 +--- trl/trainer/cpo_trainer.py | 1088 +--------------- trl/trainer/dpo_config.py | 10 + trl/trainer/gkd_config.py | 101 +- trl/trainer/gkd_trainer.py | 440 +------ trl/trainer/grpo_config.py | 70 +- trl/trainer/grpo_trainer.py | 332 ++--- trl/trainer/judges.py | 532 ++------ trl/trainer/kto_config.py | 10 + trl/trainer/kto_trainer.py | 6 +- trl/trainer/model_config.py | 8 +- trl/trainer/nash_md_config.py | 35 +- trl/trainer/nash_md_trainer.py | 488 +------- trl/trainer/online_dpo_config.py | 10 + trl/trainer/online_dpo_trainer.py | 28 +- trl/trainer/orpo_config.py | 30 +- trl/trainer/orpo_trainer.py | 79 +- trl/trainer/ppo_config.py | 128 +- trl/trainer/ppo_trainer.py | 835 +------------ trl/trainer/prm_config.py | 10 + trl/trainer/reward_config.py | 10 + trl/trainer/reward_trainer.py | 3 +- trl/trainer/rloo_config.py | 57 +- trl/trainer/rloo_trainer.py | 84 +- trl/trainer/sft_config.py | 10 + trl/trainer/sft_trainer.py | 4 +- trl/trainer/utils.py | 15 +- trl/trainer/xpo_config.py | 33 +- trl/trainer/xpo_trainer.py | 529 +------- 126 files changed, 8028 insertions(+), 5771 deletions(-) create mode 100644 docs/source/minillm.md create mode 100644 examples/datasets/deepmath_103k.py create mode 100644 examples/scripts/openenv/browsergym.py rename tests/{ => experimental}/test_cpo_trainer.py (98%) rename tests/{ => experimental}/test_gkd_trainer.py (99%) rename tests/{ => experimental}/test_judges.py (95%) create mode 100644 tests/experimental/test_minillm_trainer.py rename tests/{ => experimental}/test_nash_md_trainer.py (98%) rename tests/{ => experimental}/test_ppo_trainer.py (97%) delete mode 100644 tests/experimental/test_trainers_args.py rename tests/{ => experimental}/test_xpo_trainer.py (97%) delete mode 100644 tests/test_trainers_args.py create mode 100644 trl/experimental/cpo/__init__.py create mode 100644 trl/experimental/cpo/cpo_config.py create mode 100644 trl/experimental/cpo/cpo_trainer.py create mode 100644 trl/experimental/gkd/__init__.py create mode 100644 trl/experimental/gkd/gkd_config.py create mode 100644 trl/experimental/gkd/gkd_trainer.py create mode 100644 trl/experimental/judges/__init__.py create mode 100644 trl/experimental/judges/judges.py create mode 100644 trl/experimental/minillm/__init__.py create mode 100644 trl/experimental/minillm/minillm_config.py create mode 100644 trl/experimental/minillm/minillm_trainer.py create mode 100644 trl/experimental/nash_md/__init__.py create mode 100644 trl/experimental/nash_md/nash_md_config.py create mode 100644 trl/experimental/nash_md/nash_md_trainer.py create mode 100644 trl/experimental/openenv/__init__.py create mode 100644 trl/experimental/openenv/utils.py create mode 100644 trl/experimental/ppo/__init__.py create mode 100644 trl/experimental/ppo/ppo_config.py create mode 100644 trl/experimental/ppo/ppo_trainer.py create mode 100644 trl/experimental/xpo/__init__.py create mode 100644 trl/experimental/xpo/xpo_config.py create mode 100644 trl/experimental/xpo/xpo_trainer.py diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 5a5fae6bf4c..2fc2192cafb 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -13,7 +13,8 @@ concurrency: jobs: trl: name: "Build and push TRL Docker image" - runs-on: ubuntu-latest + runs-on: + group: aws-general-8-plus steps: - name: Checkout code uses: actions/checkout@v4 @@ -52,7 +53,8 @@ jobs: trl-dev: name: "Build and push TRL Dev Docker image" - runs-on: ubuntu-latest + runs-on: + group: aws-general-8-plus steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/README.md b/README.md index 437b4b54cb4..c0a49208262 100644 --- a/README.md +++ b/README.md @@ -21,11 +21,11 @@ **OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows. -Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](openenv). +Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](https://huggingface.co/docs/trl/openenv). ## Overview -TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups. +TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Realtive Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups. ## Highlights @@ -92,16 +92,13 @@ trainer.train() ```python from datasets import load_dataset from trl import GRPOTrainer +from trl.rewards import accuracy_reward -dataset = load_dataset("trl-lib/tldr", split="train") - -# Dummy reward function: count the number of unique characters in the completions -def reward_num_unique_chars(completions, **kwargs): - return [len(set(c)) for c in completions] +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") trainer = GRPOTrainer( model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs=reward_num_unique_chars, + reward_funcs=accuracy_reward, train_dataset=dataset, ) trainer.train() diff --git a/docker/trl-dev/Dockerfile b/docker/trl-dev/Dockerfile index c8557048d7c..9a756a8821d 100644 --- a/docker/trl-dev/Dockerfile +++ b/docker/trl-dev/Dockerfile @@ -1,6 +1,5 @@ -FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime +FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* RUN pip install --upgrade pip uv RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]" -RUN uv pip install --system hf_transfer liger_kernel trackio peft -RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl \ No newline at end of file +RUN uv pip install --system kernels liger_kernel peft trackio \ No newline at end of file diff --git a/docker/trl/Dockerfile b/docker/trl/Dockerfile index 61a2b0dd278..8b6e2842a38 100644 --- a/docker/trl/Dockerfile +++ b/docker/trl/Dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime +FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* RUN pip install --upgrade pip uv -RUN uv pip install --system trl[liger,peft,vlm] hf_transfer trackio -RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl \ No newline at end of file +RUN uv pip install --system trl[liger,peft,vlm] kernels trackio \ No newline at end of file diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a7555a1e7ed..6fd438151ab 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -56,22 +56,14 @@ title: Examples - sections: - sections: # Sorted alphabetically - - local: cpo_trainer - title: CPO - local: dpo_trainer title: DPO - local: online_dpo_trainer title: Online DPO - - local: gkd_trainer - title: GKD - local: grpo_trainer title: GRPO - local: kto_trainer title: KTO - - local: nash_md_trainer - title: Nash-MD - - local: ppo_trainer - title: PPO - local: prm_trainer title: PRM - local: reward_trainer @@ -80,15 +72,11 @@ title: RLOO - local: sft_trainer title: SFT - - local: xpo_trainer - title: XPO title: Trainers - local: models title: Model Classes - local: model_utils title: Model Utilities - - local: judges - title: Judges - local: callbacks title: Callbacks - local: data_utils @@ -107,14 +95,32 @@ title: BEMA for Reference Model - local: bco_trainer title: BCO + - local: cpo_trainer + title: CPO - local: gfpo title: GFPO + - local: gkd_trainer + title: GKD - local: gold_trainer title: GOLD - local: grpo_with_replay_buffer title: GRPO With Replay Buffer - local: gspo_token title: GSPO-token + - local: judges + title: Judges + - local: minillm + title: MiniLLM + - local: nash_md_trainer + title: Nash-MD + - local: orpo_trainer + title: ORPO + - local: papo_trainer + title: PAPO + - local: ppo_trainer + title: PPO + - local: xpo_trainer + title: XPO - local: openenv title: OpenEnv Integration - local: orpo_trainer diff --git a/docs/source/clis.md b/docs/source/clis.md index 666584decf4..54c8c1055f9 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -26,7 +26,7 @@ Currently supported commands are: You can launch training directly from the CLI by specifying required arguments like the model and dataset: - + ```bash @@ -53,6 +53,35 @@ trl reward \ --dataset_name trl-lib/ultrafeedback_binarized ``` + + + +```bash +trl grpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward +``` + + + + +```bash +trl rloo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward +``` + + + + +```bash +trl kto \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/kto-mix-14k +``` + @@ -60,7 +89,7 @@ trl reward \ To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file: - + ```yaml @@ -105,6 +134,55 @@ Launch with: trl reward --config reward_config.yaml ``` + + + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/kto-mix-14k +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + @@ -114,8 +192,8 @@ TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelera You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch). - - + + ```bash trl sft \ @@ -124,8 +202,7 @@ trl sft \ --num_processes 4 ``` - - +or, with a config file: ```yaml # sft_config.yaml @@ -141,7 +218,7 @@ trl sft --config sft_config.yaml ``` - + ```bash trl dpo \ @@ -150,8 +227,7 @@ trl dpo \ --num_processes 4 ``` - - +or, with a config file: ```yaml # dpo_config.yaml @@ -167,7 +243,7 @@ trl dpo --config dpo_config.yaml ``` - + ```bash trl reward \ @@ -176,8 +252,7 @@ trl reward \ --num_processes 4 ``` - - +or, with a config file: ```yaml # reward_config.yaml @@ -192,6 +267,87 @@ Launch with: trl reward --config reward_config.yaml ``` + + + +```bash +trl grpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +num_processes: 4 +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```bash +trl rloo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +num_processes: 4 +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```bash +trl kto \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/kto-mix-14k \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/kto-mix-14k +num_processes: 4 +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + @@ -220,8 +376,8 @@ To use one of these, just pass the name to `--accelerate_config`. TRL will autom #### Example Usage - - + + ```bash trl sft \ @@ -230,8 +386,7 @@ trl sft \ --accelerate_config zero2 # or path/to/my/accelerate/config.yaml ``` - - +or, with a config file: ```yaml # sft_config.yaml @@ -247,7 +402,7 @@ trl sft --config sft_config.yaml ``` - + ```bash trl dpo \ @@ -256,8 +411,7 @@ trl dpo \ --accelerate_config zero2 # or path/to/my/accelerate/config.yaml ``` - - +or, with a config file: ```yaml # dpo_config.yaml @@ -273,7 +427,7 @@ trl dpo --config dpo_config.yaml ``` - + ```bash trl reward \ @@ -282,8 +436,7 @@ trl reward \ --accelerate_config zero2 # or path/to/my/accelerate/config.yaml ``` - - +or, with a config file: ```yaml # reward_config.yaml @@ -298,6 +451,87 @@ Launch with: trl reward --config reward_config.yaml ``` + + + +```bash +trl grpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```bash +trl rloo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```bash +trl kto \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/kto-mix-14k \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/kto-mix-14k +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + @@ -305,7 +539,7 @@ trl reward --config reward_config.yaml You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data. - + ```yaml @@ -356,6 +590,61 @@ Launch with: trl reward --config reward_config.yaml ``` + + + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: HuggingFaceH4/Polaris-Dataset-53K + - path: trl-lib/DeepMath-103K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: HuggingFaceH4/Polaris-Dataset-53K + - path: trl-lib/DeepMath-103K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: trl-lib/kto-mix-14k + - path: argilla/ultrafeedback-binarized-preferences-cleaned +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + diff --git a/docs/source/cpo_trainer.md b/docs/source/cpo_trainer.md index 3dcdb0e11cd..e1ff2a198a4 100644 --- a/docs/source/cpo_trainer.md +++ b/docs/source/cpo_trainer.md @@ -24,7 +24,7 @@ Below is the script to train the model: ```python # train_cpo.py from datasets import load_dataset -from trl import CPOConfig, CPOTrainer +from trl.experimental.cpo import CPOConfig, CPOTrainer from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") @@ -44,7 +44,7 @@ accelerate launch train_cpo.py ## Expected dataset type -CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +CPO requires a [preference dataset](dataset_formats#preference). The [`experimental.cpo.CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Example script @@ -80,31 +80,31 @@ The abstract from the paper is the following: > Direct Preference Optimization (DPO) is a widely used offline preference optimization algorithm that reparameterizes reward functions in reinforcement learning from human feedback (RLHF) to enhance simplicity and training stability. In this work, we propose SimPO, a simpler yet more effective approach. The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. This reward formulation better aligns with model generation and eliminates the need for a reference model, making it more compute and memory efficient. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. We compare SimPO to DPO and its latest variants across various state-of-the-art training setups, including both base and instruction-tuned models like Mistral and Llama3. We evaluated on extensive instruction-following benchmarks, including AlpacaEval 2, MT-Bench, and the recent challenging Arena-Hard benchmark. Our results demonstrate that SimPO consistently and significantly outperforms existing approaches without substantially increasing response length. Specifically, SimPO outperforms DPO by up to 6.4 points on AlpacaEval 2 and by up to 7.5 points on Arena-Hard. Our top-performing model, built on Llama3-8B-Instruct, achieves a remarkable 44.7 length-controlled win rate on AlpacaEval 2 -- surpassing Claude 3 Opus on the leaderboard, and a 33.8 win rate on Arena-Hard -- making it the strongest 8B open-source model. -The SimPO loss is integrated in the [`CPOTrainer`], as it's an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, just turn on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and set the `simpo_gamma` to a recommended value. +The SimPO loss is integrated in the [`experimental.cpo.CPOTrainer`], as it's an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, just turn on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`experimental.cpo.CPOConfig`] and set the `simpo_gamma` to a recommended value. ### CPO-SimPO -We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`]. +We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`experimental.cpo.CPOConfig`]. ### AlphaPO -The [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) (AlphaPO) method by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi is also implemented in the [`CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of SimPO loss. The abstract from the paper is the following: +The [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) (AlphaPO) method by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi is also implemented in the [`experimental.cpo.CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of SimPO loss. The abstract from the paper is the following: > Reinforcement Learning with Human Feedback (RLHF) and its variants have made huge strides toward the effective alignment of large language models (LLMs) to follow instructions and reflect human values. More recently, Direct Alignment Algorithms (DAAs) have emerged in which the reward modeling stage of RLHF is skipped by characterizing the reward directly as a function of the policy being learned. Some popular examples of DAAs include Direct Preference Optimization (DPO) and Simple Preference Optimization (SimPO). These methods often suffer from likelihood displacement, a phenomenon by which the probabilities of preferred responses are often reduced undesirably. In this paper, we argue that, for DAAs the reward (function) shape matters. We introduce AlphaPO, a new DAA method that leverages an α-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and overoptimization. Compared to SimPO, one of the best performing DAAs, AlphaPO leads to about 7% to 10% relative improvement in alignment performance for the instruct versions of Mistral-7B and Llama3-8B while achieving 15% to 50% relative improvement over DPO on the same models. The analysis and results presented highlight the importance of the reward shape and how one can systematically change it to affect training dynamics, as well as improve alignment performance. -To use this loss as described in the paper, we can set the `loss_type="alphapo"` which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values in the [`CPOConfig`]. Alternatively, you can manually set `loss_type="simpo"`, `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value. +To use this loss as described in the paper, we can set the `loss_type="alphapo"` which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values in the [`experimental.cpo.CPOConfig`]. Alternatively, you can manually set `loss_type="simpo"`, `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value. ## Loss functions -The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported: +The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`experimental.cpo.CPOConfig`]. The following loss functions are supported: | `loss_type=` | Description | | --- | --- | | `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | | `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. | | `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). | -| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. | -| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. | +| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`experimental.cpo.CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`experimental.cpo.CPOConfig`] and `simpo_gamma` to a recommended value. | +| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`experimental.cpo.CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. | ### For Mixture of Experts Models: Enabling the auxiliary loss @@ -116,11 +116,11 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype ## CPOTrainer -[[autodoc]] CPOTrainer +[[autodoc]] experimental.cpo.CPOTrainer - train - save_model - push_to_hub ## CPOConfig -[[autodoc]] CPOConfig +[[autodoc]] experimental.cpo.CPOConfig diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 7d98b8abe7b..5f86257875d 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -387,21 +387,21 @@ Choosing the right dataset type depends on the task you are working on and the s | Trainer | Expected dataset type | | --- | --- | -| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | -| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | | [`GRPOTrainer`] | [Prompt-only](#prompt-only) | | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | -| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | -| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`PPOTrainer`] | Tokenized language modeling | | [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | | [`RLOOTrainer`] | [Prompt-only](#prompt-only) | | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) | -| [`XPOTrainer`] | [Prompt-only](#prompt-only) | +| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling | +| [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) | ## Using any dataset with TRL: preprocessing and conversion diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 8e9f0ac41b5..9e524c6a940 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -253,7 +253,7 @@ model = AutoModelForCausalLM.from_pretrained( "mistralai/mixtral-8x7b-v0.1", load_in_4bit=True, quantization_config=bnb_config, - attn_implementation="flash_attention_2", + attn_implementation="kernels-community/flash-attn2", dtype=torch.bfloat16, device_map="auto", ) diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index d9cba0b9114..598f8150bdd 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -37,26 +37,30 @@ These notebooks are easier to run and are designed for quick experimentation wit Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more. - File | Description | +| File | Description | | --- | --- | | [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty, and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. | -| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. | | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. | -| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`OpenAIPairwiseJudge`] to judge model generations. | -| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. | +| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`experimental.judges.HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. | +| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`experimental.gkd.GKDTrainer`] to fine-tune a model. | | [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. | | [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | | [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. | | [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. | | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. | -| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. | +| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | +| [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM | +| [`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM | +| [`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. | +| [`examples/scripts/openenv/wordle.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/wordle.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Wordle environment and vLLM. | | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`experimental.orpo.ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | -| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | -| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | +| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | +| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | | [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). | | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. | | [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. | @@ -66,7 +70,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models, so users may see unexpected behaviour in other model architectures. | | [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. | | [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. | -| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. | +| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`experimental.xpo.XPOTrainer`] to fine-tune a model. | ## Distributed Training (for scripts) diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md index 73be330c637..b703a1712b9 100644 --- a/docs/source/gkd_trainer.md +++ b/docs/source/gkd_trainer.md @@ -19,26 +19,23 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface. ## Usage tips -The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely: +The [`experimental.gkd.GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`experimental.gkd.GKDConfig`] namely: * `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch. -* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher. +* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher. * `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two. The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. > [!WARNING] -> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. +> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. The basic API is as follows: ```python from datasets import Dataset -from trl import GKDConfig, GKDTrainer -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, -) +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl.experimental.gkd import GKDConfig, GKDTrainer NUM_DUMMY_SAMPLES = 100 @@ -92,11 +89,11 @@ The dataset should be formatted as a list of "messages" where each message is a ## GKDTrainer -[[autodoc]] GKDTrainer +[[autodoc]] experimental.gkd.GKDTrainer - train - save_model - push_to_hub ## GKDConfig -[[autodoc]] GKDConfig +[[autodoc]] experimental.gkd.GKDConfig diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index 61f68b2029a..ae2591bb3bd 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -13,7 +13,7 @@ Key capabilities: 1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ. 2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher. -3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run. +3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`experimental.gkd.GKDTrainer`], so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run. > [!NOTE] > GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on. @@ -27,7 +27,7 @@ messages). Important configuration flags on [`GOLDConfig`] include: * `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens. * `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid matched/unmatched loss. -* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`, controlling the generalized JSD interpolation and on-policy +* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy sampling ratio. A minimal end-to-end example: diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index a3d99953706..92a40b009d2 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -14,10 +14,10 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi ## Quick start -This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here: +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K). You can view the data in the dataset here: -To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md). - ## Advanced Example -Let's level this up a bit by training a model to interact with a more complex environment. We'll use the game word guessing game [wordle](https://www.nytimes.com/games/wordle/index.html) from the `textarena` environment. +Let's level this up a bit by training a model to interact with a more complex environment. We'll use the game word guessing game [wordle](https://www.nytimes.com/games/wordle/index.html) from the [`TextArena`](https://meta-pytorch.org/OpenEnv/environments/textarena/) environment. ### The TextArena Environment [TextArena](https://huggingface.co/papers/2504.11442) is an open-source collection of competitive text-based games designed to evaluate reasoning skills in LLMs using textual games like Wordle, Snake, Tic-Tac-Toe, and more. Research has shown that such games improve model performance on reasoning tasks. -![image of textarena](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/text_arena_evals.png) +![image of TextArena](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/text_arena_evals.png) -We will use the `textarena` environment to train a model to play Wordle. The environment is a simple text based response environment that allows the model to interact with the game by making guesses and receive feedback on them. +We will use the `TextArena` environment to train a model to play Wordle. The environment is a simple text based response environment that allows the model to interact with the game by making guesses and receive feedback on them. ### Wordle Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. Also, it is a purely language based game that requires no external tools or knowledge. Furthermore, we found that models from 1 billion parameters and up are able to improve on wordle and only require 8 tokens to generate a guess, which makes the game a good benchmark to experiment with Reinforcement Learning environments without significant compute requirements. > [!NOTE] How does Wordle work? -> Wordle is a word guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess, the environment will provide feedback on the correctness of the guess. The player wins if they guess the word in 6 guesses or less. It challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment. +> Wordle is a word guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess, the environment will provide feedback on the correctness of the guess. The player wins if they guess the word in 6 guesses or fewer. It challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment. > > For example, if the wordle environment returns the following feedback: > @@ -233,9 +370,9 @@ Wordle is a useful game to train a model on because it requires the model to rea > G U E S S > X G Y X X > ``` -> The model has guessed the word "GUESS" and the environment has provided feedback as the letters X, G, and Y. Referring to colors in the original game blank, green, and yellow. From this feedback, the model should learn that the word is "GUESS" is incorrect. The letter "E" is in the word, but in the wrong position. The letter "U" is correct and in the correct position. +> The model has guessed the word "GUESS" and the environment has provided feedback as the letters X, G, and Y. Referring to colors in the original game as blank, green, and yellow. From this feedback, the model should learn that the word "GUESS" is incorrect. The letter "E" is in the word, but in the wrong position. The letter "U" is correct and in the correct position. -In the TextArena environment, reward is only given when the model wins the game. The reward is 1.0 if the model wins, and 0.0 otherwise. This is not a very efficient reward signal for the model, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` allows you to add any custom reward function you want to the script. +In the TextArena environment, a reward is only given when the model wins the game. The reward is 1.0 if the model wins, and 0.0 otherwise. This is not a very efficient reward signal for the model, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` allows you to add any custom reward function you want to the script. ### Rollout Function @@ -243,12 +380,12 @@ The rollout function runs one full Wordle episode, prompting the model for a gue ```python def rollout_once( + trainer: GRPOTrainer, env: TextArenaEnv, tokenizer: AutoTokenizer, - args: GRPOConfig, dataset_prompt: str, - cli_args: argparse.Namespace, system_prompt: str, + max_turns: int, ) -> dict[str, list]: result = env.reset() observation = result.observation @@ -263,7 +400,7 @@ def rollout_once( correct_scores: list[float] = [] guess_counts: dict[str, int] = {} - for _turn in range(cli_args.max_turns): + for _turn in range(max_turns): # when the game is over the environment will return a done=True if result.done: break @@ -282,20 +419,15 @@ def rollout_once( enable_thinking=False, ) - # generate the completion from the model using vLLM - vllm_result = request_vllm_completion( - prompt_text, - args, - endpoint=cli_args.vllm_endpoint, - timeout=cli_args.request_timeout, - fallback=cli_args, - ) - prompt_ids.extend(vllm_result["prompt_ids"]) - completion_ids.extend(vllm_result["completion_ids"]) - logprobs.extend(vllm_result["logprobs"]) - completion_text = vllm_result.get("text") or tokenizer.decode( - vllm_result["completion_ids"], skip_special_tokens=True + # Generate completion using trainer (works for both colocate and server modes) + rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0] + prompt_ids.extend(rollout_outputs["prompt_ids"]) + completion_ids.extend(rollout_outputs["completion_ids"]) + logprobs.extend(rollout_outputs["logprobs"]) + completion_text = rollout_outputs.get("text") or tokenizer.decode( + rollout_outputs["completion_ids"], skip_special_tokens=True ) + # extract the guess from the completion guess = extract_guess(completion_text) @@ -307,9 +439,9 @@ def rollout_once( feedback = extract_wordle_feedback(observation) # Update guess counts - previous_occurrences = guess_counts[guess] + previous_occurrences = guess_counts.get(guess, 0) repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts)) - guess_counts[guess] += 1 + guess_counts[guess] = previous_occurrences + 1 # calculate custom reward signals from the feedback if not feedback: @@ -391,11 +523,11 @@ trainer = GRPOTrainer( ], train_dataset=dataset, args=grpo_config, - rollout_func=lambda prompts, args, processing_class: rollout_func( + rollout_func=lambda prompts, trainer: rollout_func( env=env, tokenizer=tokenizer, prompts=prompts, - args=args, + trainer=trainer, cli_args=cli_args, system_prompt=system_prompt, ), @@ -405,31 +537,56 @@ trainer.train() ### Running the Advanced Example -The example requires two GPUs: +You can run the Wordle example in either colocate mode (1 GPU) or server mode (2 GPUs): + + + + + +**Colocate mode (1 GPU, recommended)** + +```bash +python examples/scripts/openenv/wordle.py --vllm-mode colocate +``` + +This runs vLLM in the same process as training, requiring only a single GPU. + + + + + +**Server mode (2+ GPUs, scalable)** ```bash # Terminal 1: Start vLLM inference server CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000 # Terminal 2: Run GRPO training with OpenEnv -CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000 ``` -Again, you can manually start the TextArena environment in a Docker container before running the training. -In this case, initialize the client with -`client = TextArenaEnv(base_url="http://0.0.0.0:8001")` -instead of -`client = TextArenaEnv.from_docker_image("registry.hf.space/burtenshaw-textarena:latest")`: +This runs vLLM as a separate server process, useful when you want to: +- Share the inference server across multiple training jobs +- Use multiple GPUs for the vLLM server (via `--tensor-parallel-size`) +- Scale up training to many GPUs while sharing a single inference endpoint + + + + + +You can also manually start the TextArena environment in a Docker container before running the training: ```bash # Launch the TextArena environment docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest ``` +Then connect to it using `--env-mode docker-local--env-host localhost --env-port 8001`. + ### Results -The resulting model improves it's performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters. +The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters. -We experimented larger models like `gpt-oss-20b` and found that model was able to consistently win the game. However, this requires a lot of compute to train and the model. Why not try this out yourself? \ No newline at end of file +We experimented with larger models like `gpt-oss-20b` and found that the model was able to consistently win the game. However, this requires a lot of compute to train the model. Why not try this out yourself? diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 6467548d8ea..bdc41263013 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -142,7 +142,7 @@ training_args = GRPOConfig( top_p=0.99, top_k=100, temperature=0.99, - num_completions=8, # = num_return_sequences in the paper + num_generations=8, # = num_return_sequences in the paper num_iterations=1, # = ppo_epochs in the paper per_device_train_batch_size=4, gradient_accumulation_steps=32, @@ -232,6 +232,28 @@ trainer = PAPOTrainer( ) ``` +### The Art of Scaling Reinforcement Learning + +**📜 Paper**: https://huggingface.co/papers/2510.13786 + +A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency and propose a best-practice recipe called ScaleRL. + +You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following configs: + +```python +from trl import GRPOConfig + +config = GRPOConfig( + loss_type="cispo", + epsilon_high=5.0, + num_generations=16, + scale_rewards="batch", + cast_lm_head_to_fp32=True +) +``` + + + ## Direct Policy Optimization Papers relating to the [`DPOTrainer`] @@ -534,7 +556,7 @@ training_args = RLOOConfig( ## Contrastive Preference Optimization -Papers relating to the [`CPOTrainer`] +Papers relating to the [`experimental.cpo.CPOTrainer`] ### AlphaPO -- Reward shape matters for LLM alignment @@ -543,7 +565,7 @@ Papers relating to the [`CPOTrainer`] AlphaPO is a new Direct Alignment Algorithms (DAAs) method that leverages an alpha-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and over-optimization. To reproduce the paper's setting, use this configuration: ```python -from trl import CPOConfig +from trl.experimental.cpo import CPOConfig # Mistral-Instruct from Table 3 of the paper training_args = CPOConfig( @@ -624,12 +646,12 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data. -To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`GKDTrainer`] and [`GKDConfig`]: +To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]: ```python -from trl import GKDConfig +from trl.experimental.gkd import GKDConfig -config = GKDConfig( +training_args = GKDConfig( lmbda=1.0, # student produces rollouts for all batches beta=1.0, # to ensure reverse-kl as the loss function teacher_model_name_or_path="teacher-model", # specify the teacher model @@ -649,3 +671,29 @@ config = GOLDConfig( ) ``` + +### Knowledge Distillation of Large Language Models + +**📜 Paper**: https://huggingface.co/papers/2306.08543 + +MiniLLM is the first on-policy knowledge distillation method, which minimizes the sequence-level reverse KLD between the teacher and the student model and is optimized by reinforcement learning. + +It is a generalized version of [Think Machine Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals. + +Alternatively, you can use the [`experimental.MiniLLMTrainer`] and [`experimental.MiniLLMConfig`] to perform MiniLLM distillation as follows: + +```python +from datasets import load_dataset +from trl.experimental.minillm import MiniLLMTrainer + +dataset = load_dataset("trl-lib/tldr", split="train") + +trainer = MiniLLMTrainer( + model="Qwen/Qwen3-0.6B", + teacher_model="Qwen/Qwen3-1.7B", + train_dataset=dataset, +) +trainer.train() +``` + +For more details, see the [MiniLLM Trainer documentation](minillm) documentation. diff --git a/docs/source/peft_integration.md b/docs/source/peft_integration.md index bd196dd99bf..221d9b7071b 100644 --- a/docs/source/peft_integration.md +++ b/docs/source/peft_integration.md @@ -146,7 +146,8 @@ After training your reward adapter and pushing it to the Hub: ```python from peft import LoraConfig -from trl import AutoModelForCausalLMWithValueHead, PPOTrainer +from trl import AutoModelForCausalLMWithValueHead +from trl.experimental.ppo import PPOTrainer model_name = "huggyllama/llama-7b" rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md index 1dabbc4177c..3f7ea2ee73f 100644 --- a/docs/source/ppo_trainer.md +++ b/docs/source/ppo_trainer.md @@ -1,5 +1,11 @@ # PPO Trainer + + +**Deprecation Notice**: PPOTrainer and PPOConfig have been moved to `trl.experimental.ppo` and will be removed from `trl.trainer` in TRL 0.29.0. Please update your imports to use `from trl.experimental.ppo import PPOConfig, PPOTrainer` instead. See [issue #4466](https://github.com/huggingface/trl/issues/4466) for more information. + + + [![model badge](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl) TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347). @@ -228,11 +234,11 @@ python -m openrlbenchmark.rlops_multi_metrics \ ## PPOTrainer -[[autodoc]] PPOTrainer +[[autodoc]] experimental.ppo.PPOTrainer - train - save_model - push_to_hub ## PPOConfig -[[autodoc]] PPOConfig +[[autodoc]] experimental.ppo.PPOConfig diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index 3a89cf55120..6661762af93 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -24,15 +24,12 @@ trainer.train() ```python from trl import GRPOTrainer from datasets import load_dataset - -# Define a simple reward function (count unique chars as example) -def reward_function(completions, **kwargs): - return [len(set(completion.lower())) for completion in completions] +from trl.rewards import accuracy_reward trainer = GRPOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", # Start from SFT model - train_dataset=load_dataset("trl-lib/tldr", split="train"), - reward_funcs=reward_function, + train_dataset=load_dataset("trl-lib/DeepMath-103K", split="train"), + reward_funcs=accuracy_reward, ) trainer.train() ``` diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index f258c0a20f8..f92ebb29edb 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -90,6 +90,33 @@ from trl import SFTConfig training_args = SFTConfig(..., packing=True, max_length=512) ``` +## PEFT for parameter-efficient fine-tuning + +Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA are among the most effective techniques for reducing memory usage during training. Instead of training all model parameters, PEFT methods train only a small number of adapter parameters, significantly reducing memory requirements and enabling fine-tuning of larger models on limited hardware. + +For comprehensive details on using PEFT with TRL, including various adapter methods, quantization options, and advanced configurations, see [PEFT Integration](peft_integration). + +To use PEFT for reducing memory usage: + +```python +from datasets import load_dataset +from peft import LoraConfig +from trl import SFTTrainer + +dataset = load_dataset("trl-lib/Capybara", split="train") + +peft_config = LoraConfig() + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=dataset, + peft_config=peft_config, +) +``` + +PEFT can be combined with other memory reduction techniques such as quantization (4-bit or 8-bit) for even greater memory savings. See [PEFT Integration](peft_integration) for quantization examples. + + ## Liger for reducing peak memory usage > [Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. @@ -138,7 +165,7 @@ training_args = KTOConfig(..., use_liger_kernel=True) ```python -from trl import GKDConfig +from trl.experimental.gkd import GKDConfig training_args = GKDConfig(..., use_liger_kernel=True) ``` @@ -161,7 +188,7 @@ Padding-free batching is an alternative approach for reducing memory usage. In t ```python from trl import DPOConfig -training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` @@ -170,7 +197,7 @@ training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_imple ```python from trl import SFTConfig -training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) ``` @@ -247,7 +274,7 @@ training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False) ```python -from trl import PPOConfig +from trl.experimental.ppo import PPOConfig training_args = PPOConfig(..., ds3_gather_for_generation=False) ``` @@ -290,3 +317,5 @@ training_args = RLOOConfig(..., vllm_enable_sleep_mode=True) + +Offloading the vLLM weights and cache helps keep GPU memory usage low, which can be particularly beneficial when training large models or using limited GPU resources. However, waking the vLLM engine from sleep mode introduces some host–device transfer latency, which may slightly impact training speed. diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 36d315e678d..68173d218da 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -15,10 +15,10 @@ This post-training method was contributed by [Costa Huang](https://github.com/vw ## Quick start -This example demonstrates how to train a model using the RLOO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here: +This example demonstrates how to train a model using the RLOO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K). You can view the data in the dataset here: