diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a4ca28675bc..7a6f7ee497c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -64,8 +64,6 @@ title: GRPO - local: kto_trainer title: KTO - - local: orpo_trainer - title: ORPO - local: prm_trainer title: PRM - local: reward_trainer @@ -115,6 +113,8 @@ title: MiniLLM - local: nash_md_trainer title: Nash-MD + - local: orpo_trainer + title: ORPO - local: papo_trainer title: PAPO - local: ppo_trainer diff --git a/docs/source/community_tutorials.md b/docs/source/community_tutorials.md index 333aa973b5a..a412ee7d917 100644 --- a/docs/source/community_tutorials.md +++ b/docs/source/community_tutorials.md @@ -15,7 +15,7 @@ Community tutorials are made by active members of the Hugging Face community who | Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) | | Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) | | Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) | -| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | +| Preference Optimization | [`experimental.orpo.ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | | Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) | ### Videos diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 8faf2f11fc7..5f86257875d 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -387,20 +387,20 @@ Choosing the right dataset type depends on the task you are working on and the s | Trainer | Expected dataset type | | --- | --- | -| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | -| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) | | [`GRPOTrainer`] | [Prompt-only](#prompt-only) | | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | -| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) | | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | -| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling | | [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | | [`RLOOTrainer`] | [Prompt-only](#prompt-only) | | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) | +| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling | | [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) | ## Using any dataset with TRL: preprocessing and conversion diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 6fe254f326e..598f8150bdd 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -58,7 +58,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM | | [`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. | | [`examples/scripts/openenv/wordle.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/wordle.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Wordle environment and vLLM. | -| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`experimental.orpo.ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | | [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). | diff --git a/docs/source/index.md b/docs/source/index.md index 95f964b671a..01f20cf5c0b 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -41,10 +41,10 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL - [`SFTTrainer`] - [`DPOTrainer`] -- [`ORPOTrainer`] - [`KTOTrainer`] - [`experimental.bco.BCOTrainer`] 🧪 - [`experimental.cpo.CPOTrainer`] 🧪 +- [`experimental.orpo.ORPOTrainer`] 🧪 ### Knowledge distillation diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index 3092de2d2ae..555f0858316 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -34,7 +34,7 @@ Below is the script to train the model: ```python # train_orpo.py from datasets import load_dataset -from trl import ORPOConfig, ORPOTrainer +from trl.experimental.orpo import ORPOConfig, ORPOTrainer from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") @@ -79,9 +79,9 @@ Here are some other factors to consider when choosing a programming language for ## Expected dataset type -ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +ORPO requires a [preference dataset](dataset_formats#preference). The [`experimental.orpo.ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. -Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. +Although the [`experimental.orpo.ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. ## Example script @@ -121,11 +121,11 @@ While training and evaluating, we record the following reward metrics: ## ORPOTrainer -[[autodoc]] ORPOTrainer +[[autodoc]] experimental.orpo.ORPOTrainer - train - save_model - push_to_hub ## ORPOConfig -[[autodoc]] ORPOConfig +[[autodoc]] experimental.orpo.ORPOConfig diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index e256a4277ad..36eb32a565b 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -63,7 +63,8 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config +from trl import ModelConfig, ScriptArguments, get_peft_config +from trl.experimental.orpo import ORPOConfig, ORPOTrainer # Enable logging in a Hugging Face Space diff --git a/tests/test_orpo_trainer.py b/tests/experimental/test_orpo_trainer.py similarity index 98% rename from tests/test_orpo_trainer.py rename to tests/experimental/test_orpo_trainer.py index 70f087ac948..95a82234909 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/experimental/test_orpo_trainer.py @@ -17,9 +17,9 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from trl import ORPOConfig, ORPOTrainer +from trl.experimental.orpo import ORPOConfig, ORPOTrainer -from .testing_utils import TrlTestCase, require_peft +from ..testing_utils import TrlTestCase, require_peft class TestORPOTrainer(TrlTestCase): diff --git a/trl/experimental/orpo/__init__.py b/trl/experimental/orpo/__init__.py new file mode 100644 index 00000000000..17960ce5b18 --- /dev/null +++ b/trl/experimental/orpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .orpo_config import ORPOConfig +from .orpo_trainer import ORPOTrainer + + +__all__ = ["ORPOConfig", "ORPOTrainer"] diff --git a/trl/experimental/orpo/orpo_config.py b/trl/experimental/orpo/orpo_config.py new file mode 100644 index 00000000000..1db0cf94086 --- /dev/null +++ b/trl/experimental/orpo/orpo_config.py @@ -0,0 +1,179 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class ORPOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.orpo.ORPOTrainer`]. + + This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int`, *optional*): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the relative ratio loss weight in the ORPO loss. In the + [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the + [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, *optional*): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from the model to W&B or Comet during evaluation. + is_encoder_decoder (`bool`, *optional*): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322, but the fix has not yet been released. We + # add a temporary workaround here, which can be removed once the fix is available—likely in Transformers 4.57.2. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_prompt_length: int | None = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + max_completion_length: int | None = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " + "denoted by λ." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + label_pad_token_id: int = field( + default=-100, + metadata={ + "help": "Label pad token id. This argument is required if you want to use the default data collator." + }, + ) + padding_value: int | None = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, + ) + is_encoder_decoder: bool | None = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " + "argument, you need to specify if the model returned by the callable is an encoder-decoder model." + }, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py new file mode 100644 index 00000000000..9e3b1987eba --- /dev/null +++ b/trl/experimental/orpo/orpo_trainer.py @@ -0,0 +1,1049 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import random +import textwrap +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState, logging +from datasets import Dataset +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + is_comet_available, + is_torch_xla_available, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available, is_torch_fx_proxy + +from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + DPODataCollatorWithPadding, + add_bos_token_if_needed, + add_eos_token_if_needed, + disable_dropout_in_model, + log_table_to_comet_experiment, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) +from .orpo_config import ORPOConfig + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +logger = logging.get_logger(__name__) + + +class ORPOTrainer(BaseTrainer): + r""" + Initialize ORPOTrainer. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. + args ([`experimental.orpo.ORPOConfig`]): + The ORPO config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the + sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + """ + + _tag_names = ["trl", "orpo"] + _name = "ORPO" + _paper = { + "title": "ORPO: Monolithic Preference Optimization without Reference Model", + "id": "2403.07691", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{hong2024orpo, + title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, + author = {Jiwoo Hong and Noah Lee and James Thorne}, + year = 2024, + eprint = {arXiv:2403.07691} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + args: ORPOConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + dtype = model_init_kwargs.get("dtype", "auto") + if dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." + ) + model_init_kwargs["dtype"] = dtype + model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if processing_class is None: + raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") + if args.max_length is None: + logger.warning( + "`max_length` is not set in the ORPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + ) + max_length = 512 + else: + max_length = args.max_length + if args.max_prompt_length is None: + logger.warning( + "`max_prompt_length` is not set in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + ) + max_prompt_length = 128 + else: + max_prompt_length = args.max_prompt_length + + if args.max_completion_length is None and self.is_encoder_decoder: + logger.warning( + "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + ) + self.max_completion_length = 128 + else: + self.max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + logger.warning( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.processing_class = processing_class + + self.beta = args.beta + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed, and apply the chat template if needed + train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + + b)[len(enc(a)):]`. Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict: + """Tokenize a single row from a ORPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + + chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, + we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length + of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.processing_class(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + a != b + for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True) + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt. Avoid adding if it's already there + prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( + self.processing_class.bos_token_id, + prompt_len_input_ids, + prompt_tokens, + chosen_prompt_len_input_ids, + chosen_tokens, + rejected_prompt_len_input_ids, + rejected_tokens, + ) + + # add EOS token to end of answer. Avoid adding if it's already there + chosen_tokens, rejected_tokens = add_eos_token_if_needed( + self.processing_class.eos_token_id, chosen_tokens, rejected_tokens + ) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.processing_class( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.processing_class( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.processing_class( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + if is_torch_xla_available(): + # Pad the sequences to global max_length to avoid TorchXLA recompilation + for k in batch: + if "labels" in k or self.is_encoder_decoder: + pad_value = self.label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = self.padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) + return batch + + @staticmethod + def concatenated_inputs( + batch: dict[str, list | torch.LongTensor], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: torch.device | None = None, + ) -> dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: + A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors + of shape (batch_size, sequence_length). + is_encoder_decoder: + Whether the model is an encoder-decoder model. + label_pad_token_id: + The label pad token id. + padding_value: + The padding value to use for the concatenated inputs_ids. + device: + The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def odds_ratio_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: + Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: + Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO + loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for + the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the + rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes. + """ + + # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) + ) + ratio = F.logsigmoid(log_odds) + losses = self.beta * ratio + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: + Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are + ignored. Shape: (batch_size, sequence_length) + average_log_prob: + If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the + log probabilities of the (non-masked) tokens. + label_pad_token_id: The label pad token id. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the + given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == label_pad_token_id, 0, labels) + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + # orpo chosen nll loss is computed over the full prompt and response + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1, :] + rejected_logits = all_logits[len_chosen:, :-1, :] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, list | torch.LongTensor], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() + metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( + chosen_rewards - rejected_rewards + ).mean() + metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() + metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() + metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( + policy_rejected_logits.detach().mean() + ).mean() + metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( + policy_chosen_logits.detach().mean() + ).mean() + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() + metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() + metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() + if is_torch_xla_available(): + xm.mark_step() # needed because .item() calls + for k, v in metrics.items(): + metrics[k] = v.item() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ): + if not self.use_dpo_data_collator: + logger.warning( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.generate_from_model(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy"], + data=[ + [prompt, pol[len(prompt) :]] + for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index ffc63516b2e..e5f30575fd7 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -12,168 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field -from typing import Any +import warnings +from dataclasses import dataclass -from transformers import TrainingArguments +from ..experimental.orpo import ORPOConfig as _ORPOConfig @dataclass -class ORPOConfig(TrainingArguments): - r""" - Configuration class for the [`ORPOTrainer`]. - - This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int`, *optional*): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the relative ratio loss weight in the ORPO loss. In the - [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the - [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, *optional*): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from the model to W&B or Comet during evaluation. - is_encoder_decoder (`bool`, *optional*): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - """ - - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] - - # Parameters whose default values are overridden from TrainingArguments - learning_rate: float = field( - default=1e-6, - metadata={"help": "The initial learning rate for AdamW."}, - ) - logging_steps: float = field( - default=10, - metadata={ - "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " - "will be interpreted as ratio of total training steps." - }, - ) - gradient_checkpointing: bool = field( - default=True, - metadata={ - "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." - }, - ) - bf16: bool | None = field( - default=None, - metadata={ - "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " - "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " - "`fp16` is not set." - }, - ) - # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue - # was fixed in https://github.com/huggingface/transformers/pull/41322, but the fix has not yet been released. We - # add a temporary workaround here, which can be removed once the fix is available—likely in Transformers 4.57.2. - lr_scheduler_kwargs: dict | str | None = field( - default=None, - metadata={ - "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " - "restarts." - }, - ) - - max_length: int | None = field( - default=1024, - metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, - ) - max_prompt_length: int | None = field( - default=512, - metadata={ - "help": "Maximum length of the prompt. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, - ) - max_completion_length: int | None = field( - default=None, - metadata={ - "help": "Maximum length of the completion. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, - ) - beta: float = field( - default=0.1, - metadata={ - "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " - "denoted by λ." - }, - ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model."}, - ) - label_pad_token_id: int = field( - default=-100, - metadata={ - "help": "Label pad token id. This argument is required if you want to use the default data collator." - }, - ) - padding_value: int | None = field( - default=None, - metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, - ) - truncation_mode: str = field( - default="keep_end", - metadata={ - "help": "Truncation mode to use when the prompt is too long.", - "choices": ["keep_end", "keep_start"], - }, - ) - generate_during_eval: bool = field( - default=False, - metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, - ) - is_encoder_decoder: bool | None = field( - default=None, - metadata={ - "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " - "argument, you need to specify if the model returned by the callable is an encoder-decoder model." - }, - ) - model_init_kwargs: dict[str, Any] | None = field( - default=None, - metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " - "from a string." - }, - ) - dataset_num_proc: int | None = field( - default=None, - metadata={"help": "Number of processes to use for processing the dataset."}, - ) - +class ORPOConfig(_ORPOConfig): def __post_init__(self): - self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 - + warnings.warn( + "The `ORPOConfig` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.orpo import ORPOConfig`. The current import path will be removed and no longer " + "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223." + ) super().__post_init__() diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index dd8d3a7eb4b..b91a59eef35 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -12,1047 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -import os -import random -import textwrap import warnings -from collections import defaultdict -from collections.abc import Callable -from contextlib import nullcontext -from pathlib import Path -from typing import Any, Literal +from dataclasses import dataclass -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from accelerate import PartialState, logging -from datasets import Dataset -from torch import autocast -from torch.utils.data import DataLoader -from transformers import ( - AutoModelForCausalLM, - BaseImageProcessor, - DataCollator, - FeatureExtractionMixin, - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - is_comet_available, - is_torch_xla_available, - is_wandb_available, -) -from transformers.trainer_callback import TrainerCallback -from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_peft_available, is_torch_fx_proxy +from ..experimental.orpo import ORPOTrainer as _ORPOTrainer -from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt -from .base_trainer import BaseTrainer -from .orpo_config import ORPOConfig -from .utils import ( - DPODataCollatorWithPadding, - add_bos_token_if_needed, - add_eos_token_if_needed, - disable_dropout_in_model, - log_table_to_comet_experiment, - pad_to_length, - peft_module_casting_to_bf16, - selective_log_softmax, -) - -if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training - - -if is_wandb_available(): - import wandb - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - -logger = logging.get_logger(__name__) - - -class ORPOTrainer(BaseTrainer): - r""" - Initialize ORPOTrainer. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. - args ([`ORPOConfig`]): - The ORPO config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - """ - - _tag_names = ["trl", "orpo"] - _name = "ORPO" - _paper = { - "title": "ORPO: Monolithic Preference Optimization without Reference Model", - "id": "2403.07691", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{hong2024orpo, - title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, - author = {Jiwoo Hong and Noah Lee and James Thorne}, - year = 2024, - eprint = {arXiv:2403.07691} - }"""), - } - - def __init__( - self, - model: PreTrainedModel | nn.Module | str | None = None, - args: ORPOConfig | None = None, - data_collator: DataCollator | None = None, - train_dataset: Dataset | None = None, - eval_dataset: Dataset | dict[str, Dataset] | None = None, - processing_class: PreTrainedTokenizerBase - | BaseImageProcessor - | FeatureExtractionMixin - | ProcessorMixin - | None = None, - model_init: Callable[[], PreTrainedModel] | None = None, - callbacks: list[TrainerCallback] | None = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - peft_config: dict | None = None, - compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - dtype = model_init_kwargs.get("dtype", "auto") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - model_init_kwargs["dtype"] = dtype - model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") - - if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = get_peft_model(model, peft_config) - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve." - ) - - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - - if self.is_encoder_decoder: - self.decoder_start_token_id = model.config.decoder_start_token_id - self.pad_token_id = model.config.pad_token_id - - if processing_class is None: - raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") - if args.max_length is None: - logger.warning( - "`max_length` is not set in the ORPOConfig's init" - " it will default to `512` by default, but you should do it yourself in the future.", - ) - max_length = 512 - else: - max_length = args.max_length - if args.max_prompt_length is None: - logger.warning( - "`max_prompt_length` is not set in the ORPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - max_prompt_length = 128 - else: - max_prompt_length = args.max_prompt_length - - if args.max_completion_length is None and self.is_encoder_decoder: - logger.warning( - "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - self.max_completion_length = 128 - else: - self.max_completion_length = args.max_completion_length - - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - - self.max_length = max_length - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.processing_class = processing_class - - self.beta = args.beta - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - - # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the - # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and - # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().main_process_first(): - # Extract the prompt if needed, and apply the chat template if needed - train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - train_dataset = train_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc - ) - train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - if eval_dataset is not None: - eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - ) - eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - def build_tokenized_answer(self, prompt, answer): - """ - Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + - b)[len(enc(a)):]`. Reference: - https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - """ - - full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] - - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] - - # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` - full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) - - # Prepare input tokens for token by token comparison - full_input_ids = np.array(full_tokenized["input_ids"]) - - if len(full_input_ids) != len(full_concat_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = len(prompt_input_ids) - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. - if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] - answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - - return dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - input_ids=answer_input_ids, - attention_mask=answer_attention_mask, +@dataclass +class ORPOTrainer(_ORPOTrainer): + def __init__(self, *args, **kwargs): + warnings.warn( + "The `ORPOTrainer` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.orpo import ORPOTrainer`. The current import path will be removed and no longer " + "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223." ) - - def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict: - """Tokenize a single row from a ORPO specific dataset. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + - chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, - we truncate the chosen/rejected. - - We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length - of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. - """ - batch = {} - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - if not self.is_encoder_decoder: - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. https://github.com/LianjiaTech/BELLE/issues/337 - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") - prompt_tokens = self.processing_class(prompt, add_special_tokens=False) - prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} - - if not isinstance(chosen, str): - raise ValueError(f"chosen should be an str but got {type(chosen)}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) - - if not isinstance(rejected, str): - raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) - - # Last prompt token might get merged by tokenizer and - # it should not be included for generation if that happens - prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) - - chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) - rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) - prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) - - for k, v in prompt_tokens.items(): - prompt_tokens[k] = v[:prompt_len_input_ids] - - # Make sure prompts only have one different token at most an - # and length only differs by 1 at most - num_diff_tokens = sum( - a != b - for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True) - ) - num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) - if num_diff_tokens > 1 or num_diff_len > 1: - raise ValueError( - "Chosen and rejected prompt_input_ids might only differ on the " - "last token due to tokenizer merge ops." - ) - - # add BOS token to head of prompt. Avoid adding if it's already there - prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( - self.processing_class.bos_token_id, - prompt_len_input_ids, - prompt_tokens, - chosen_prompt_len_input_ids, - chosen_tokens, - rejected_prompt_len_input_ids, - rejected_tokens, - ) - - # add EOS token to end of answer. Avoid adding if it's already there - chosen_tokens, rejected_tokens = add_eos_token_if_needed( - self.processing_class.eos_token_id, chosen_tokens, rejected_tokens - ) - - longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - if self.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] - elif self.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - - # if that's still too long, truncate the response - for answer_tokens in [chosen_tokens, rejected_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] - - # Create labels - chosen_sequence_tokens = { - k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] - } - rejected_sequence_tokens = { - k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] - } - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(chosen_tokens["prompt_input_ids"]) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(rejected_tokens["prompt_input_ids"]) - - for k, toks in { - "chosen_": chosen_sequence_tokens, - "rejected_": rejected_sequence_tokens, - "": prompt_tokens, - }.items(): - for type_key, tokens in toks.items(): - if type_key == "token_type_ids": - continue - batch[f"{k}{type_key}"] = tokens - - else: - chosen_tokens = self.processing_class( - chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - rejected_tokens = self.processing_class( - rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - prompt_tokens = self.processing_class( - prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True - ) - - batch["chosen_labels"] = chosen_tokens["input_ids"] - batch["rejected_labels"] = rejected_tokens["input_ids"] - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] - - if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): - batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["rejected_labels"]) - ) - batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["chosen_labels"]) - ) - - if is_torch_xla_available(): - # Pad the sequences to global max_length to avoid TorchXLA recompilation - for k in batch: - if "labels" in k or self.is_encoder_decoder: - pad_value = self.label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = self.padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) - return batch - - @staticmethod - def concatenated_inputs( - batch: dict[str, list | torch.LongTensor], - is_encoder_decoder: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - device: torch.device | None = None, - ) -> dict[str, torch.LongTensor]: - """Concatenate the chosen and rejected inputs into a single tensor. - - Args: - batch: - A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors - of shape (batch_size, sequence_length). - is_encoder_decoder: - Whether the model is an encoder-decoder model. - label_pad_token_id: - The label pad token id. - padding_value: - The padding value to use for the concatenated inputs_ids. - device: - The device for the concatenated inputs. - - Returns: - A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. - """ - concatenated_batch = {} - - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) - else: - max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("rejected", "concatenated") - concatenated_batch[concatenated_key] = torch.cat( - ( - concatenated_batch[concatenated_key], - pad_to_length(batch[k], max_length, pad_value=pad_value), - ), - dim=0, - ).to(device=device) - - if is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1).to(device=device) - ) - - return concatenated_batch - - def odds_ratio_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: - Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) - policy_rejected_logps: - Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO - loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for - the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the - rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes. - """ - - # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) - log_odds = (policy_chosen_logps - policy_rejected_logps) - ( - torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) - ) - ratio = F.logsigmoid(log_odds) - losses = self.beta * ratio - - chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() - rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() - - return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: - Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are - ignored. Shape: (batch_size, sequence_length) - average_log_prob: - If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the - log probabilities of the (non-masked) tokens. - label_pad_token_id: The label pad token id. - is_encoder_decoder: Whether the model is an encoder-decoder model. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the - given logits. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels = torch.where(labels == label_pad_token_id, 0, labels) - - per_token_logps = selective_log_softmax(logits, labels) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def concatenated_forward( - self, model: nn.Module, batch: dict[str, list | torch.LongTensor] - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - """ - concatenated_batch = self.concatenated_inputs( - batch, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - padding_value=self.padding_value, - device=self.accelerator.device, - ) - len_chosen = batch["chosen_labels"].shape[0] - - model_kwargs = ( - { - "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), - } - if self.is_encoder_decoder - else {} - ) - - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) - all_logits = outputs.logits - - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - if self.is_encoder_decoder: - labels = concatenated_batch["concatenated_labels"].clone() - else: - labels = concatenated_batch["concatenated_input_ids"].clone() - attention_mask = concatenated_batch["concatenated_attention_mask"] - labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) - # orpo chosen nll loss is computed over the full prompt and response - chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) - - all_logps = self.get_batch_logps( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=True, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - if not self.is_encoder_decoder: - chosen_logits = all_logits[:len_chosen, :-1, :] - rejected_logits = all_logits[len_chosen:, :-1, :] - else: - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) - - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) - - def get_batch_loss_metrics( - self, - model, - batch: dict[str, list | torch.LongTensor], - train_eval: Literal["train", "eval"] = "train", - ): - """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - - forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] - - losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( - policy_chosen_logps, policy_rejected_logps - ) - # full ORPO loss - loss = policy_nll_loss - losses.mean() - - reward_accuracies = (chosen_rewards > rejected_rewards).float() - - prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() - metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() - metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() - metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( - chosen_rewards - rejected_rewards - ).mean() - metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() - metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() - metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( - policy_rejected_logits.detach().mean() - ).mean() - metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( - policy_chosen_logits.detach().mean() - ).mean() - metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() - metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() - metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() - if is_torch_xla_available(): - xm.mark_step() # needed because .item() calls - for k, v in metrics.items(): - metrics[k] = v.item() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss - - return loss, metrics - - def compute_loss( - self, - model: PreTrainedModel | nn.Module, - inputs: dict[str, torch.Tensor | Any], - return_outputs=False, - num_items_in_batch=None, - ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") - - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - - # force log the metrics - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return (loss, metrics) - return loss - - def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - return policy_output_decoded - - def prediction_step( - self, - model: PreTrainedModel | nn.Module, - inputs: dict[str, torch.Tensor | Any], - prediction_loss_only: bool, - ignore_keys: list[str] | None = None, - ): - if not self.use_dpo_data_collator: - logger.warning( - "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " - "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" - ) - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") - - # force log the metrics - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return (loss.detach(), None, None) - - # logits for the chosen and rejected samples from model - logits_dict = { - "eval_logits/chosen": metrics["eval_logits/chosen"], - "eval_logits/rejected": metrics["eval_logits/rejected"], - } - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: bool | None = None, - ignore_keys: list[str] | None = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded = self.generate_from_model(self.model, random_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy"], - data=[ - [prompt, pol[len(prompt) :]] - for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True) - ], - ) - if "wandb" in self.args.report_to: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: float | None = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - def _shift_right(self, input_ids): - if self.decoder_start_token_id is None: - raise ValueError( - "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) - shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = self.decoder_start_token_id - - if self.pad_token_id is None: - raise ValueError("model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) - - return shifted_input_ids - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) + super().__init__(*args, **kwargs)