diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f3439f04270..6fc8e2ea783 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -64,8 +64,6 @@ title: GRPO - local: kto_trainer title: KTO - - local: nash_md_trainer - title: Nash-MD - local: orpo_trainer title: ORPO - local: ppo_trainer @@ -117,6 +115,8 @@ title: Judges - local: minillm title: MiniLLM + - local: nash_md_trainer + title: Nash-MD - local: papo_trainer title: PAPO - local: xpo_trainer diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 10d527eea9c..3444742987f 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -393,7 +393,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) | | [`GRPOTrainer`] | [Prompt-only](#prompt-only) | | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | -| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) | | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | | [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`PPOTrainer`] | Tokenized language modeling | diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index b8d95ec336c..2e1ea944187 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -51,7 +51,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. | | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. | -| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. | +| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | diff --git a/docs/source/index.md b/docs/source/index.md index f216434a885..1ac0d5d7321 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -25,8 +25,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL - [`GRPOTrainer`] ⚡️ - [`RLOOTrainer`] ⚡️ - [`OnlineDPOTrainer`] ⚡️ -- [`NashMDTrainer`] ⚡️ - [`PPOTrainer`] +- [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️ - [`experimental.xpo.XPOTrainer`] 🧪 ⚡️ ### Reward modeling diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md index 93df716d410..e86592769a5 100644 --- a/docs/source/nash_md_trainer.md +++ b/docs/source/nash_md_trainer.md @@ -28,8 +28,8 @@ Below is the script to train the model: ```python # train_nash_md.py from datasets import load_dataset -from trl import NashMDConfig, NashMDTrainer from trl.experimental.judges import PairRMJudge +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") @@ -64,7 +64,7 @@ The best programming language depends on personal preference, the complexity of ## Expected dataset type -Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`experimental.nash_md.NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ## Usage tips @@ -91,7 +91,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht ### Encourage EOS token generation -We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]: +We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.nash_md.NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.nash_md.NashMDConfig`]: ```python training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) @@ -144,16 +144,16 @@ While training and evaluating, we record the following reward metrics: * `logps/rejected`: The mean log probabilities of the reference completions. * `val/model_contain_eos_token`: The amount of times the model's output contains the eos token. * `val/ref_contain_eos_token`: The amount of times the mixture's output contains the eos token. -* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`]. -* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`]. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.nash_md.NashMDConfig`]. +* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.nash_md.NashMDConfig`]. ## NashMDTrainer -[[autodoc]] NashMDTrainer +[[autodoc]] experimental.nash_md.NashMDTrainer - train - save_model - push_to_hub ## NashMDConfig -[[autodoc]] NashMDConfig +[[autodoc]] experimental.nash_md.NashMDConfig diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index 9a87d91f313..c4618752239 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -10,9 +10,9 @@ This document will guide you through the process of using vLLM with TRL for fast > > - [`GRPOTrainer`] > - [`OnlineDPOTrainer`] -> - [`NashMDTrainer`] -> - [`experimental.xpo.XPOTrainer`] > - [`RLOOTrainer`] +> - [`experimental.nash_md.NashMDTrainer`] +> - [`experimental.xpo.XPOTrainer`] ## 🚀 How can I use vLLM with TRL to speed up training? @@ -105,7 +105,7 @@ trainer.train() ```python from datasets import load_dataset -from trl import NashMDTrainer, NashMDConfig +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer dataset = load_dataset("trl-lib/tldr", split="train") @@ -379,7 +379,7 @@ training_args = OnlineDPOConfig( ```python -from trl import NashMDConfig +from trl.experimental.nash_md import NashMDConfig training_args = NashMDConfig( ..., @@ -454,7 +454,7 @@ training_args = OnlineDPOConfig( ```python -from trl import NashMDConfig +from trl.experimental.nash_md import NashMDConfig training_args = NashMDConfig( ..., diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 2f9fa870762..ac461b1802a 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -63,14 +63,13 @@ from trl import ( LogCompletionsCallback, ModelConfig, - NashMDConfig, - NashMDTrainer, ScriptArguments, TrlParser, get_kbit_device_map, get_quantization_config, ) from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer # Enable logging in a Hugging Face Space diff --git a/tests/test_nash_md_trainer.py b/tests/experimental/test_nash_md_trainer.py similarity index 98% rename from tests/test_nash_md_trainer.py rename to tests/experimental/test_nash_md_trainer.py index 7e7449e0fe5..1f4ec1d255b 100644 --- a/tests/test_nash_md_trainer.py +++ b/tests/experimental/test_nash_md_trainer.py @@ -17,9 +17,9 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from transformers.utils import is_peft_available -from trl import NashMDConfig, NashMDTrainer +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft +from ..testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft if is_peft_available(): diff --git a/tests/experimental/test_trainers_args.py b/tests/experimental/test_trainers_args.py index 6384b228d11..09cb448b296 100644 --- a/tests/experimental/test_trainers_args.py +++ b/tests/experimental/test_trainers_args.py @@ -18,6 +18,7 @@ from trl.experimental.bco import BCOConfig, BCOTrainer from trl.experimental.cpo import CPOConfig, CPOTrainer +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer from trl.experimental.xpo import XPOConfig, XPOTrainer from ..testing_utils import TrlTestCase, require_sklearn @@ -113,6 +114,28 @@ def test_cpo(self): assert trainer.args.model_init_kwargs == {"trust_remote_code": True} assert trainer.args.dataset_num_proc == 4 + @pytest.mark.parametrize("mixtures_coef_list", [False, True]) + def test_nash_md(self, mixtures_coef_list): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = NashMDConfig( + self.tmp_dir, + mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6], + ) + trainer = NashMDTrainer( + args=training_args, + processing_class=tokenizer, + model=model, + ref_model=ref_model, + reward_funcs=reward_model, + train_dataset=dataset, + ) + assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6]) + @pytest.mark.parametrize("alpha_list", [False, True]) def test_xpo(self, alpha_list): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 19679f93d9d..2097008b49b 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -22,8 +22,6 @@ FDivergenceType, KTOConfig, KTOTrainer, - NashMDConfig, - NashMDTrainer, OnlineDPOConfig, OnlineDPOTrainer, ORPOConfig, @@ -150,28 +148,6 @@ def test_kto(self): assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} assert trainer.args.dataset_num_proc == 4 - @pytest.mark.parametrize("mixtures_coef_list", [False, True]) - def test_nash_md(self, mixtures_coef_list): - model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" - tokenizer = AutoTokenizer.from_pretrained(model_id) - model = AutoModelForCausalLM.from_pretrained(model_id) - ref_model = AutoModelForCausalLM.from_pretrained(model_id) - reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") - training_args = NashMDConfig( - self.tmp_dir, - mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6], - ) - trainer = NashMDTrainer( - args=training_args, - processing_class=tokenizer, - model=model, - ref_model=ref_model, - reward_funcs=reward_model, - train_dataset=dataset, - ) - assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6]) - @pytest.mark.parametrize("beta_list", [False, True]) def test_online_dpo(self, beta_list): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/trl/experimental/nash_md/__init__.py b/trl/experimental/nash_md/__init__.py new file mode 100644 index 00000000000..9369b5312ba --- /dev/null +++ b/trl/experimental/nash_md/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .nash_md_config import NashMDConfig +from .nash_md_trainer import NashMDTrainer + + +__all__ = ["NashMDConfig", "NashMDTrainer"] diff --git a/trl/experimental/nash_md/nash_md_config.py b/trl/experimental/nash_md/nash_md_config.py new file mode 100644 index 00000000000..0f74236eba9 --- /dev/null +++ b/trl/experimental/nash_md/nash_md_config.py @@ -0,0 +1,46 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ...trainer.online_dpo_config import OnlineDPOConfig + + +@dataclass +class NashMDConfig(OnlineDPOConfig): + r""" + Configuration class for the [`experimental.nash_md.NashMDTrainer`]. + + Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: + + Parameters: + mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): + Logit mixture coefficient for the model and reference model. If a list of floats is provided then the + mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the + epochs. + """ + + mixture_coef: list[float] = field( + default_factory=lambda: [0.5], + metadata={ + "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided " + "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the " + "rest of the epochs." + }, + ) + + def __post_init__(self): + super().__post_init__() + if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1: + self.mixture_coef = self.mixture_coef[0] diff --git a/trl/experimental/nash_md/nash_md_trainer.py b/trl/experimental/nash_md/nash_md_trainer.py new file mode 100644 index 00000000000..845354b0d69 --- /dev/null +++ b/trl/experimental/nash_md/nash_md_trainer.py @@ -0,0 +1,489 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap +from collections.abc import Callable +from typing import Any + +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import EvalPrediction +from transformers.training_args import OptimizerNames +from transformers.utils import is_peft_available + +from ...data_utils import is_conversational, maybe_apply_chat_template +from ...models.modeling_base import GeometricMixtureWrapper +from ...models.utils import unwrap_model_for_generation +from ...trainer.judges import BasePairwiseJudge +from ...trainer.online_dpo_trainer import OnlineDPOTrainer +from ...trainer.utils import SIMPLE_CHAT_TEMPLATE, empty_cache, get_reward, selective_log_softmax, truncate_right +from .nash_md_config import NashMDConfig + + +if is_peft_available(): + from peft import PeftModel + + +class NashMDTrainer(OnlineDPOTrainer): + """ + Trainer for the Nash-MD method. + + It is implemented as a subclass of [`OnlineDPOTrainer`]. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model ([`PreTrainedModelWrapper`]): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation + and loss. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized. + reward_funcs ([`~transformers.PreTrainedModel`]): + The reward model to score completions with, preferably an + [`~transformers.AutoModelForSequenceClassification`]. + judge ([`BasePairwiseJudge`]): + The judge to use for pairwise comparison of model completions. + args ([`experimental.nash_md.NashMDConfig`]): + The NashMD config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the + sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "nash-md"] + _name = "Nash-MD" + _paper = { + "title": "Nash Learning from Human Feedback", + "id": "2312.00886", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{munos2024nash, + title = {{Nash Learning from Human Feedback}}, + author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=Y5AmNYiyCQ} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module = None, + ref_model: PreTrainedModel | nn.Module = None, + reward_funcs: PreTrainedModel | nn.Module | None = None, + judge: BasePairwiseJudge | None = None, + args: NashMDConfig | None = None, + data_collator: Callable | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + ) -> None: + super().__init__( + model=model, + ref_model=ref_model, + reward_funcs=reward_funcs, + judge=judge, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=processing_class, + peft_config=peft_config, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self._mixture_coef = self.args.mixture_coef + + # Overwrite the stats dictionary to include NashMD specific statistics + self.stats = { + # Remove "non_score_reward", "rlhf_reward", "scores_margin" + # Add "mixture_coef" + "loss/kl": [], + "objective/entropy": [], + "loss/score": [], + "rewards/probabilities": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + "val/model_contain_eos_token": [], + "val/ref_contain_eos_token": [], + "beta": [], + "mixture_coef": [], + } + if self.reward_funcs is not None: + if len(self.reward_funcs) != 1: + raise ValueError("NashMDTrainer only supports one reward function/model.") + self.reward_funcs = self.reward_funcs[0] + self.stats["rewards/chosen"] = [] + self.stats["rewards/rejected"] = [] + + @property + def mixture_coef(self): + if isinstance(self._mixture_coef, list): + epoch = self.state.epoch + return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1] + else: + return self._mixture_coef + + def _generate_completions(self, model, prompts): + # Generate completions from the policy model. + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx: + model_output = unwrapped_policy_for_gen_ctx.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + # Get the DDP/FSDP unwrapped version of the main model. + # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used). + policy_model_for_gmw = self.accelerator.unwrap_model(model) + + # Determine the correct reference model for GeometricMixtureWrapper. + # This also needs to be DDP/FSDP unwrapped. + ref_model_for_gmw: torch.nn.Module + if self.ref_model is None: + # No explicit ref_model is provided. + # Use the base of the main `model` if it's a PEFT model. + # policy_model_for_gmw is already DDP-unwrapped. + if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel): + ref_model_for_gmw = policy_model_for_gmw.get_base_model() + else: + # Not a PEFT model (or PEFT not available), or already a base model. + # Use the DDP-unwrapped policy model itself as the reference. + ref_model_for_gmw = policy_model_for_gmw + else: + # An explicit ref_model is provided. Unwrap it for DDP/FSDP. + ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model) + + # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped. + with torch.no_grad(): # Ensure no_grad context for mixture model generation + mixture_model = GeometricMixtureWrapper( + model=policy_model_for_gmw, + ref_model=ref_model_for_gmw, + generation_config=self.generation_config, + mixture_coef=self.mixture_coef, + device=self.accelerator.device, + ) + + mixture_output = mixture_model.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + return model_output, mixture_output + + def _process_completions(self, model_output, mixture_output, prompts): + context_length = prompts["input_ids"].shape[1] + + # Process model completions + model_completion_ids = model_output[:, context_length:] + model_completion_ids, model_completion_mask = truncate_right( + model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + model_data = { + "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), + "raw": prompts["raw"], + } + + # Process reference model completions + mixture_completion_ids = mixture_output[:, context_length:] + mixture_completion_ids, mixture_completion_mask = truncate_right( + mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + mixture_data = { + "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1), + "raw": prompts["raw"], + } + + return model_data, mixture_data + + def _compute_rewards(self, model_data, mixture_data, context_length): + with torch.no_grad(): + _, model_scores, _ = get_reward( + self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + _, mixture_scores, _ = get_reward( + self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + + # Apply EOS penalty if needed + if self.args.missing_eos_penalty is not None: + model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + model_scores[~model_contain_eos] -= self.args.missing_eos_penalty + mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty + + return model_scores, mixture_scores + + def _compute_judge(self, model_data, mixture_data, context_length): + prompts = model_data["raw"] + model_data_completions = self.processing_class.batch_decode( + model_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + model_data_completions = [completion.strip() for completion in model_data_completions] + + mixture_data_completions = self.processing_class.batch_decode( + mixture_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + mixture_data_completions = [completion.strip() for completion in mixture_data_completions] + if is_conversational({"prompt": prompts[0]}): + model_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in model_data_completions + ] + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=message) for message in prompts] + model_data_completions = [template.render(messages=completion) for completion in model_data_completions] + + mixture_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in mixture_data_completions + ] + mixture_data_completions = [ + template.render(messages=completion) for completion in mixture_data_completions + ] + + probability = self.judge.judge( + prompts, + list(zip(model_data_completions, mixture_data_completions, strict=True)), + return_scores=True, + ) + return torch.tensor(probability, device=model_data["input_ids"].device) + + def _compute_logprobs(self, model, model_data, context_length): + def compute_logprobs_for_data(m, data): + output = m(data["input_ids"], attention_mask=data["attention_mask"]) + logits = output.logits[:, context_length - 1 : -1] + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) + return token_logprobs + + # Compute logprobs for model completions under the model + model_logprobs_model_data = compute_logprobs_for_data(model, model_data) + + # Compute logprobs of model completions under the reference model + with torch.no_grad(): + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + + # Mask padding tokens + model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 + model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + + return (model_logprobs_model_data, ref_logprobs_model_data) + + def _compute_losses( + self, + model_logprobs_model_data, + ref_logprobs_model_data, + probability, + ): + # reinforce score where 0.5 is a control variate + score = (probability - 0.5) * model_logprobs_model_data.sum(1) + + # kl divergence via reinforce + with torch.no_grad(): + log_ratio = model_logprobs_model_data - ref_logprobs_model_data + kl_div_log = log_ratio.sum(1) + kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1) + + # final loss + loss = self.beta * kl_div_loss - score + + return loss.mean(), score, kl_div_log + + def _log_statistics( + self, + model_data, + mixture_data, + model_logprobs_model_data, + ref_logprobs_model_data, + probability, + score, + kl_div, + context_length, + model_scores=None, + mixture_scores=None, + ): + # Helper function to gather and compute mean + def gather_mean(tensor): + return self.accelerator.gather_for_metrics(tensor).mean().item() + + # Log score + self.stats["loss/score"].append(gather_mean(score)) + # Log KL divergence + self.stats["loss/kl"].append(gather_mean(kl_div)) + + # Log logprobs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum)) + self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum)) + + # Log rewards + if self.reward_funcs is not None: + self.stats["rewards/chosen"].append(gather_mean(model_scores)) + self.stats["rewards/rejected"].append(gather_mean(mixture_scores)) + + # Log probabilities + self.stats["rewards/probabilities"].append(gather_mean(probability)) + + # Calculate entropy for model data + entropy_model_data = -model_logprobs_model_data.sum(1) + self.stats["objective/entropy"].append(gather_mean(entropy_model_data)) + + # Calculate margins + margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum + self.stats["rewards/margins"].append(gather_mean(margin)) + + # Calculate accuracy + accuracy = (margin > 0).float() + self.stats["rewards/accuracies"].append(gather_mean(accuracy)) + + # Log EOS token statistics + model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) + self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float())) + + # Log beta and mixture coef + self.stats["beta"].append(self.beta) + self.stats["mixture_coef"].append(self.mixture_coef) + + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + model.train() + + # Apply chat template and tokenize the input + batch_size = len(next(iter(inputs.values()))) + prompts = inputs["prompt"] + inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] + inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] + inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] + inputs = self.data_collator(inputs) + + # need the prompt_ only + inputs = self._prepare_inputs(inputs) + context_length = inputs["prompt_input_ids"].shape[1] + prompts = { + "input_ids": inputs["prompt_input_ids"], + "attention_mask": inputs["prompt_attention_mask"], + "raw": prompts, + } + del inputs + + # Sample completions from both the model and the reference model + model_output, mixture_output = self._generate_completions(model, prompts) + + # Process model completions + model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts) + + # Compute rewards + if self.reward_funcs is not None: + model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length) + # probability of the model data vs the mixture data + probability = F.sigmoid(model_scores - mixture_scores) + else: + model_scores, mixture_scores = None, None + probability = self._compute_judge(model_data, mixture_data, context_length) + + # Compute logprobs + model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length) + + # Compute loss + loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability) + + # Log everything + self._log_statistics( + model_data, + mixture_data, + model_logprobs_model_data.detach(), + ref_logprobs_model_data, + probability, + score.detach(), + kl_div.detach(), + context_length, + model_scores, + mixture_scores, + ) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps diff --git a/trl/trainer/nash_md_config.py b/trl/trainer/nash_md_config.py index ddc653ee619..9da8247f59f 100644 --- a/trl/trainer/nash_md_config.py +++ b/trl/trainer/nash_md_config.py @@ -12,35 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field +import warnings +from dataclasses import dataclass -from .online_dpo_config import OnlineDPOConfig +from ..experimental.nash_md import NashMDConfig as _NashMDConfig @dataclass -class NashMDConfig(OnlineDPOConfig): - r""" - Configuration class for the [`NashMDTrainer`]. - - Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: - - Parameters: - mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): - Logit mixture coefficient for the model and reference model. If a list of floats is provided then the - mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the - epochs. - """ - - mixture_coef: list[float] = field( - default_factory=lambda: [0.5], - metadata={ - "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided " - "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the " - "rest of the epochs." - }, - ) - +class NashMDConfig(_NashMDConfig): def __post_init__(self): + warnings.warn( + "The `NashMDConfig` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.nash_md import NashMDConfig`. The current import path will be removed and no " + "longer supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223." + ) super().__post_init__() - if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1: - self.mixture_coef = self.mixture_coef[0] diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index 144bce6fbec..23aae32b9e5 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -12,483 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import textwrap -from collections.abc import Callable -from typing import Any +import warnings +from dataclasses import dataclass -import jinja2 -import torch -import torch.nn as nn -import torch.nn.functional as F -from datasets import Dataset, IterableDataset -from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - TrainerCallback, -) -from transformers.trainer_utils import EvalPrediction -from transformers.training_args import OptimizerNames -from transformers.utils import is_peft_available +from ..experimental.nash_md import NashMDTrainer as _NashMDTrainer -from ..data_utils import is_conversational, maybe_apply_chat_template -from ..models.modeling_base import GeometricMixtureWrapper -from ..models.utils import unwrap_model_for_generation -from .nash_md_config import NashMDConfig -from .online_dpo_trainer import OnlineDPOTrainer -from .utils import ( - SIMPLE_CHAT_TEMPLATE, - empty_cache, - get_reward, - selective_log_softmax, - truncate_right, -) - -if is_peft_available(): - from peft import PeftModel - - -class NashMDTrainer(OnlineDPOTrainer): - """ - Trainer for the Nash-MD method. - - It is implemented as a subclass of [`OnlineDPOTrainer`]. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an `AutoModelForCausalLM`. - ref_model ([`PreTrainedModelWrapper`]): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. - reward_funcs ([`~transformers.PreTrainedModel`]): - The reward model to score completions with, preferably an - [`~transformers.AutoModelForSequenceClassification`]. - judge ([`experimental.judges.BasePairwiseJudge`]): - The judge to use for pairwise comparison of model completions. - args ([`NashMDConfig`]): - The NashMD config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - peft_config (`dict`): - The peft config to use for training. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - """ - - _tag_names = ["trl", "nash-md"] - _name = "Nash-MD" - _paper = { - "title": "Nash Learning from Human Feedback", - "id": "2312.00886", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @inproceedings{munos2024nash, - title = {{Nash Learning from Human Feedback}}, - author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot}, - year = 2024, - booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, - publisher = {OpenReview.net}, - url = {https://openreview.net/forum?id=Y5AmNYiyCQ} - }"""), - } - - def __init__( - self, - model: PreTrainedModel | nn.Module = None, - ref_model: PreTrainedModel | nn.Module = None, - reward_funcs: PreTrainedModel | nn.Module | None = None, - judge=None, - args: NashMDConfig | None = None, - data_collator: Callable | None = None, - train_dataset: Dataset | IterableDataset | None = None, - eval_dataset: Dataset | dict[str, Dataset] | None = None, - processing_class: PreTrainedTokenizerBase - | BaseImageProcessor - | FeatureExtractionMixin - | ProcessorMixin - | None = None, - peft_config: dict | None = None, - compute_metrics: Callable[[EvalPrediction], dict] | None = None, - callbacks: list[TrainerCallback] | None = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - ) -> None: - super().__init__( - model=model, - ref_model=ref_model, - reward_funcs=reward_funcs, - judge=judge, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - reward_processing_classes=processing_class, - peft_config=peft_config, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - self._mixture_coef = self.args.mixture_coef - - # Overwrite the stats dictionary to include NashMD specific statistics - self.stats = { - # Remove "non_score_reward", "rlhf_reward", "scores_margin" - # Add "mixture_coef" - "loss/kl": [], - "objective/entropy": [], - "loss/score": [], - "rewards/probabilities": [], - "rewards/accuracies": [], - "rewards/margins": [], - "logps/chosen": [], - "logps/rejected": [], - "val/model_contain_eos_token": [], - "val/ref_contain_eos_token": [], - "beta": [], - "mixture_coef": [], - } - if self.reward_funcs is not None: - if len(self.reward_funcs) != 1: - raise ValueError("NashMDTrainer only supports one reward function/model.") - self.reward_funcs = self.reward_funcs[0] - self.stats["rewards/chosen"] = [] - self.stats["rewards/rejected"] = [] - - @property - def mixture_coef(self): - if isinstance(self._mixture_coef, list): - epoch = self.state.epoch - return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1] - else: - return self._mixture_coef - - def _generate_completions(self, model, prompts): - # Generate completions from the policy model. - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx: - model_output = unwrapped_policy_for_gen_ctx.generate( - input_ids=prompts["input_ids"], - attention_mask=prompts["attention_mask"], - generation_config=self.generation_config, - ) - - # Get the DDP/FSDP unwrapped version of the main model. - # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used). - policy_model_for_gmw = self.accelerator.unwrap_model(model) - - # Determine the correct reference model for GeometricMixtureWrapper. - # This also needs to be DDP/FSDP unwrapped. - ref_model_for_gmw: torch.nn.Module - if self.ref_model is None: - # No explicit ref_model is provided. - # Use the base of the main `model` if it's a PEFT model. - # policy_model_for_gmw is already DDP-unwrapped. - if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel): - ref_model_for_gmw = policy_model_for_gmw.get_base_model() - else: - # Not a PEFT model (or PEFT not available), or already a base model. - # Use the DDP-unwrapped policy model itself as the reference. - ref_model_for_gmw = policy_model_for_gmw - else: - # An explicit ref_model is provided. Unwrap it for DDP/FSDP. - ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model) - - # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped. - with torch.no_grad(): # Ensure no_grad context for mixture model generation - mixture_model = GeometricMixtureWrapper( - model=policy_model_for_gmw, - ref_model=ref_model_for_gmw, - generation_config=self.generation_config, - mixture_coef=self.mixture_coef, - device=self.accelerator.device, - ) - - mixture_output = mixture_model.generate( - input_ids=prompts["input_ids"], - attention_mask=prompts["attention_mask"], - generation_config=self.generation_config, - ) - - return model_output, mixture_output - - def _process_completions(self, model_output, mixture_output, prompts): - context_length = prompts["input_ids"].shape[1] - - # Process model completions - model_completion_ids = model_output[:, context_length:] - model_completion_ids, model_completion_mask = truncate_right( - model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id +@dataclass +class NashMDTrainer(_NashMDTrainer): + def __init__(self, *args, **kwargs): + warnings.warn( + "The `NashMDTrainer` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.nash_md import NashMDTrainer`. The current import path will be removed and no " + "longer supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223." ) - model_data = { - "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), - "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), - "raw": prompts["raw"], - } - - # Process reference model completions - mixture_completion_ids = mixture_output[:, context_length:] - mixture_completion_ids, mixture_completion_mask = truncate_right( - mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id - ) - mixture_data = { - "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1), - "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1), - "raw": prompts["raw"], - } - - return model_data, mixture_data - - def _compute_rewards(self, model_data, mixture_data, context_length): - with torch.no_grad(): - _, model_scores, _ = get_reward( - self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length - ) - _, mixture_scores, _ = get_reward( - self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length - ) - - # Apply EOS penalty if needed - if self.args.missing_eos_penalty is not None: - model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) - mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) - model_scores[~model_contain_eos] -= self.args.missing_eos_penalty - mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty - - return model_scores, mixture_scores - - def _compute_judge(self, model_data, mixture_data, context_length): - prompts = model_data["raw"] - model_data_completions = self.processing_class.batch_decode( - model_data["input_ids"][:, context_length:], skip_special_tokens=True - ) - model_data_completions = [completion.strip() for completion in model_data_completions] - - mixture_data_completions = self.processing_class.batch_decode( - mixture_data["input_ids"][:, context_length:], skip_special_tokens=True - ) - mixture_data_completions = [completion.strip() for completion in mixture_data_completions] - if is_conversational({"prompt": prompts[0]}): - model_data_completions = [ - [{"role": "assistant", "content": completion}] for completion in model_data_completions - ] - environment = jinja2.Environment() - template = environment.from_string(SIMPLE_CHAT_TEMPLATE) - prompts = [template.render(messages=message) for message in prompts] - model_data_completions = [template.render(messages=completion) for completion in model_data_completions] - - mixture_data_completions = [ - [{"role": "assistant", "content": completion}] for completion in mixture_data_completions - ] - mixture_data_completions = [ - template.render(messages=completion) for completion in mixture_data_completions - ] - - probability = self.judge.judge( - prompts, - list(zip(model_data_completions, mixture_data_completions, strict=True)), - return_scores=True, - ) - return torch.tensor(probability, device=model_data["input_ids"].device) - - def _compute_logprobs(self, model, model_data, context_length): - def compute_logprobs_for_data(m, data): - output = m(data["input_ids"], attention_mask=data["attention_mask"]) - logits = output.logits[:, context_length - 1 : -1] - token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) - return token_logprobs - - # Compute logprobs for model completions under the model - model_logprobs_model_data = compute_logprobs_for_data(model, model_data) - - # Compute logprobs of model completions under the reference model - with torch.no_grad(): - if self.ref_model is None: - with model.disable_adapter(): - ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) - else: - ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) - - # Mask padding tokens - model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 - model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) - ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) - - return (model_logprobs_model_data, ref_logprobs_model_data) - - def _compute_losses( - self, - model_logprobs_model_data, - ref_logprobs_model_data, - probability, - ): - # reinforce score where 0.5 is a control variate - score = (probability - 0.5) * model_logprobs_model_data.sum(1) - - # kl divergence via reinforce - with torch.no_grad(): - log_ratio = model_logprobs_model_data - ref_logprobs_model_data - kl_div_log = log_ratio.sum(1) - kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1) - - # final loss - loss = self.beta * kl_div_loss - score - - return loss.mean(), score, kl_div_log - - def _log_statistics( - self, - model_data, - mixture_data, - model_logprobs_model_data, - ref_logprobs_model_data, - probability, - score, - kl_div, - context_length, - model_scores=None, - mixture_scores=None, - ): - # Helper function to gather and compute mean - def gather_mean(tensor): - return self.accelerator.gather_for_metrics(tensor).mean().item() - - # Log score - self.stats["loss/score"].append(gather_mean(score)) - # Log KL divergence - self.stats["loss/kl"].append(gather_mean(kl_div)) - - # Log logprobs - model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) - ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) - - self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum)) - self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum)) - - # Log rewards - if self.reward_funcs is not None: - self.stats["rewards/chosen"].append(gather_mean(model_scores)) - self.stats["rewards/rejected"].append(gather_mean(mixture_scores)) - - # Log probabilities - self.stats["rewards/probabilities"].append(gather_mean(probability)) - - # Calculate entropy for model data - entropy_model_data = -model_logprobs_model_data.sum(1) - self.stats["objective/entropy"].append(gather_mean(entropy_model_data)) - - # Calculate margins - margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum - self.stats["rewards/margins"].append(gather_mean(margin)) - - # Calculate accuracy - accuracy = (margin > 0).float() - self.stats["rewards/accuracies"].append(gather_mean(accuracy)) - - # Log EOS token statistics - model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) - mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) - self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) - self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float())) - - # Log beta and mixture coef - self.stats["beta"].append(self.beta) - self.stats["mixture_coef"].append(self.mixture_coef) - - def training_step( - self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None - ) -> torch.Tensor: - model.train() - - # Apply chat template and tokenize the input - batch_size = len(next(iter(inputs.values()))) - prompts = inputs["prompt"] - inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] - inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] - inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] - inputs = self.data_collator(inputs) - - # need the prompt_ only - inputs = self._prepare_inputs(inputs) - context_length = inputs["prompt_input_ids"].shape[1] - prompts = { - "input_ids": inputs["prompt_input_ids"], - "attention_mask": inputs["prompt_attention_mask"], - "raw": prompts, - } - del inputs - - # Sample completions from both the model and the reference model - model_output, mixture_output = self._generate_completions(model, prompts) - - # Process model completions - model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts) - - # Compute rewards - if self.reward_funcs is not None: - model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length) - # probability of the model data vs the mixture data - probability = F.sigmoid(model_scores - mixture_scores) - else: - model_scores, mixture_scores = None, None - probability = self._compute_judge(model_data, mixture_data, context_length) - - # Compute logprobs - model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length) - - # Compute loss - loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability) - - # Log everything - self._log_statistics( - model_data, - mixture_data, - model_logprobs_model_data.detach(), - ref_logprobs_model_data, - probability, - score.detach(), - kl_div.detach(), - context_length, - model_scores, - mixture_scores, - ) - - if ( - self.args.torch_empty_cache_steps is not None - and self.state.global_step % self.args.torch_empty_cache_steps == 0 - ): - empty_cache() - - kwargs = {} - # For LOMO optimizers you need to explicitly use the learning rate - if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - kwargs["learning_rate"] = self._get_learning_rate() - - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training - - self.accelerator.backward(loss, **kwargs) - - return loss.detach() / self.args.gradient_accumulation_steps + super().__init__(*args, **kwargs)