diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 88ffd313329..0d70dca5bf9 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -125,6 +125,8 @@ title: PPO - local: prm_trainer title: PRM + - local: sdft_trainer + title: SDFT - local: winrate_callback title: WinRateCallback - local: xpo_trainer diff --git a/docs/source/sdft_trainer.md b/docs/source/sdft_trainer.md new file mode 100644 index 00000000000..c80c521cfdb --- /dev/null +++ b/docs/source/sdft_trainer.md @@ -0,0 +1,70 @@ +# Self-Distillation Fine-Tuning (SDFT) Trainer + +## Overview + +Self-Distillation Fine-Tuning (SDFT) is described in [Self-Distillation for Language Models](https://arxiv.org/pdf/2601.19897). +SDFT trains a student model using a teacher model on the student's generated completions, using a divergence between +student and teacher distributions. + +The abstract from the paper is the following: + +> Continual learning, enabling models to acquire new skills and knowledge without degrading existing capabilities, remains a fundamental challenge for foundation models. While on-policy reinforcement learning can reduce forgetting, it requires explicit reward functions that are often unavailable. Learning from expert demonstrations, the primary alternative, is dominated by supervised fine-tuning (SFT), which is inherently offpolicy. We introduce Self-Distillation Fine-Tuning (SDFT), a simple method that enables on-policy learning directly from demonstrations. SDFT leverages in-context learning by using a demonstration-conditioned model as its own teacher, generating on-policy training signals that preserve prior capabilities while acquiring new skills. Across skill learning and knowledge acquisition tasks, SDFT consistently outperforms SFT, achieving higher new-task accuracy while substantially reducing catastrophic forgetting. In sequential learning experiments, SDFT enables a single model to accumulate multiple skills over time without performance regression, establishing on-policy distillation as a practical path to continual learning from demonstrations. + +> [!WARNING] +> **Experimental:** APIs under `trl.experimental` may change or be removed without notice. + +## Usage tips + +- Provide a teacher model via `ref_model`. If you omit it, the trainer will create a teacher from the same checkpoint + as the student. +- Your dataset must contain `prompt` and `teacher_prompt`. If you do not have distinct teacher prompts, set + `teacher_prompt = prompt`. +- Set `generate_from_teacher=True` to generate completions using the teacher model instead of the student. + +## Quick Start + +```python +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl.experimental.sdft import SDFTConfig, SDFTTrainer + +student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer.pad_token = tokenizer.eos_token + +train_dataset = Dataset.from_dict( + { + "prompt": ["Write a haiku about the ocean."], + "teacher_prompt": ["Write a haiku about the ocean."], + } +) + +training_args = SDFTConfig(output_dir="sdft-model", per_device_train_batch_size=1) +trainer = SDFTTrainer( + model=student_model, + ref_model=teacher_model, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, +) +trainer.train() +``` + +### Expected dataset type + +The dataset must be formatted with the following columns: + +- `prompt`: text or conversational messages for the student input. +- `teacher_prompt`: text or conversational messages for the teacher input. + +## SDFTTrainer + +[[autodoc]] experimental.sdft.SDFTTrainer + - train + - save_model + - push_to_hub + +## SDFTConfig + +[[autodoc]] experimental.sdft.SDFTConfig diff --git a/tests/experimental/test_sdft_trainer.py b/tests/experimental/test_sdft_trainer.py new file mode 100644 index 00000000000..0091802ce48 --- /dev/null +++ b/tests/experimental/test_sdft_trainer.py @@ -0,0 +1,126 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import Dataset +from transformers import AutoTokenizer + +from trl.experimental.sdft import SDFTConfig, SDFTTrainer + +from ..testing_utils import TrlTestCase + + +MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + + +def build_dataset(): + return Dataset.from_dict( + { + "prompt": ["Write a short poem about the sea."], + "teacher_prompt": ["Write a short poem about the sea."], + } + ) + + +class TestSDFTTrainer(TrlTestCase): + def _build_args(self): + return SDFTConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + num_generations=1, + report_to="none", + ) + + def _build_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + return tokenizer + + def test_init_creates_default_teacher(self): + args = self._build_args() + tokenizer = self._build_tokenizer() + trainer = SDFTTrainer( + model=MODEL_ID, + args=args, + processing_class=tokenizer, + train_dataset=build_dataset(), + ) + assert trainer.ref_model is not None + assert trainer.ref_model is not trainer.model + + def test_init_with_ref_model_id(self): + args = self._build_args() + tokenizer = self._build_tokenizer() + trainer = SDFTTrainer( + model=MODEL_ID, + ref_model=MODEL_ID, + args=args, + processing_class=tokenizer, + train_dataset=build_dataset(), + ) + assert trainer.ref_model is not None + + def test_missing_teacher_prompt_raises(self): + args = self._build_args() + tokenizer = self._build_tokenizer() + bad_dataset = Dataset.from_dict({"prompt": ["Hello"]}) + with pytest.raises(ValueError, match="teacher_prompt"): + SDFTTrainer( + model=MODEL_ID, + args=args, + processing_class=tokenizer, + train_dataset=bad_dataset, + ) + + @pytest.mark.low_priority + def test_train_updates_student_and_freezes_teacher(self): + args = SDFTConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + num_generations=1, + max_completion_length=8, + max_steps=1, + logging_steps=1, + report_to="none", + save_strategy="no", + eval_strategy="no", + ) + tokenizer = self._build_tokenizer() + trainer = SDFTTrainer( + model=MODEL_ID, + args=args, + processing_class=tokenizer, + train_dataset=build_dataset(), + ) + + student_before = {n: p.detach().clone() for n, p in trainer.model.named_parameters()} + teacher_before = {n: p.detach().clone() for n, p in trainer.ref_model.named_parameters()} + + trainer.train() + + # Student params should change + student_changed = False + for name, before in student_before.items(): + after = trainer.model.get_parameter(name).detach() + if not torch.allclose(before, after): + student_changed = True + break + assert student_changed, "Student parameters did not update after training" + + # Teacher params should remain frozen + for name, before in teacher_before.items(): + after = trainer.ref_model.get_parameter(name).detach() + assert torch.allclose(before, after), f"Teacher parameter {name} changed during training" diff --git a/trl/experimental/sdft/__init__.py b/trl/experimental/sdft/__init__.py new file mode 100644 index 00000000000..85a7818ae5c --- /dev/null +++ b/trl/experimental/sdft/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .sdft_config import SDFTConfig +from .sdft_trainer import SDFTTrainer + + +__all__ = ["SDFTConfig", "SDFTTrainer"] diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py new file mode 100644 index 00000000000..497002be0ef --- /dev/null +++ b/trl/experimental/sdft/sdft_config.py @@ -0,0 +1,591 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from transformers import TrainingArguments + + +@dataclass +class SDFTConfig(TrainingArguments): + r""" + Configuration class for the [`SDFTTrainer`]. + + This class includes only the parameters that are specific to Self-Distillation Fine-Tuning (SDFT). For a full + list of training arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that + default values in this class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`str`, `dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SDFTTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents + the model from generating different logprobs for the same input. + + > Parameters that control the data preprocessing + + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to only keep the columns needed by SDFT (`"prompt"` and `"teacher_prompt"`) in the dataset. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. + num_generations (`int` or `None`, *optional*, defaults to `8`): + Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size + * gradient_accumulation_steps) must be evenly divisible by this value. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum length of the generated completion. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible + with vLLM generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. + + > Parameters that control generation + + generation_batch_size: (`int`, *optional*): + Batch size to use for generation. If `None`, it defaults to the effective training batch size: + `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one + generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`. + steps_per_generation: (`int`, *optional*): + Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive + with `generation_batch_size`. + temperature (`float`, defaults to `1.0`): + Temperature for sampling. The higher the temperature, the more random the completions. + top_p (`float`, *optional*, defaults to `1.0`): + Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to + `1.0` to consider all tokens. + top_k (`int`, *optional*): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is + disabled and all tokens are considered. + min_p (`float`, *optional*): + Minimum token probability, which will be scaled by the probability of the most likely token. It must be a + value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. + Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat + tokens. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` + paged implementation will be used for generation instead of the default padded implementation. This + parameter is only effective when `use_vllm` is set to `False`. + cache_implementation (`str`, *optional*): + Implementation of the cache method for faster generation when `use_vllm` is set to `False`. + generation_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or + `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the + generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict + with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. + + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation + instead of the default model.generate(). Requires `vllm` to be installed. + vllm_mode (`str`, *optional*, defaults to `"server"`): + Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use + the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model + implementation. + + > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + + vllm_server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + + > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and woken + for weight sync and generation. + + > Parameters that control the training + + beta (`float`, *optional*, defaults to `0.0`): + KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving + training speed. + num_iterations (`int`, *optional*, defaults to `1`): + Number of iterations per batch (denoted as μ in the algorithm). + mask_truncated_completions (`bool`, *optional*, defaults to `False`): + When enabled, truncated completions are excluded from the loss calculation, preventing them from being + incorrectly penalized and introducing noise during training. According to the + [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. + ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): + α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `512`): + τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + top_entropy_quantile (`float`, *optional*, defaults to `1.0`): + ρ parameter from [Beyond the 80/20 Rule](https://huggingface.co/papers/2506.01939). Keeps in the policy + loss term only the top-ρ quantile of tokens by entropy of the probability distribution at each sequence + position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token; + `1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with + `mask_truncated_completions=True`, only tokens from non-truncated completions are considered. + vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): + Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed + logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL + Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework + (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation + and training backends. TIS is proposed as a remedy for this issue. + vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): + Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance + sampling ratio, improving training stability. + + > Parameters that control the logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, + it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. + num_completions_to_print (`int`, *optional*): + Number of completions to print with `rich`. If `None`, all completions are logged. + wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): + Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all prompts + are logged. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + + # Parameters that control the model and reference model + model_init_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `SDFTTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=False, + metadata={ + "help": "Whether to disable dropout in the model. This is useful for training with a reference model, as " + "it prevents the model from generating different logprobs for the same input." + }, + ) + + # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because SDFT relies on custom + # columns like `teacher_prompt` (and sometimes multimodal inputs). + remove_unused_columns: bool | None = field( + default=False, + metadata={ + "help": "Whether to only keep the columns 'prompt' and 'teacher_prompt' in the dataset. If you use any " + "additional columns (e.g., images), you should keep this to `False`." + }, + ) + max_prompt_length: int | None = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." + }, + ) + num_generations: int | None = field( + default=8, + metadata={ + "help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size " + "* gradient_accumulation_steps) must be evenly divisible by this value." + }, + ) + max_completion_length: int | None = field( + default=256, + metadata={"help": "Maximum length of the generated completion."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + shuffle_dataset: bool | None = field( + default=True, + metadata={"help": "Whether to shuffle the training dataset."}, + ) + + # Parameters that control generation + generation_batch_size: int | None = field( + default=None, + metadata={ + "help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: " + "`per_device_train_batch_size * num_processes * steps_per_generation`." + }, + ) + steps_per_generation: int | None = field( + default=None, + metadata={"help": "Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`."}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: int | None = field( + default=None, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + min_p: float | None = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + generation_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or " + "`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the " + "generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that " + "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them." + }, + ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) + use_transformers_paged: bool = field( + default=False, + metadata={ + "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the " + "`transformers` paged implementation will be used for generation instead of the default padded " + "implementation. This parameter is only effective when `use_vllm` is set to `False`." + }, + ) + cache_implementation: str | None = field( + default=None, + metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, + ) + + # Parameters that control generation acceleration powered by vLLM + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for " + "generation instead of the default model.generate(). Requires `vllm` to be installed." + }, + ) + vllm_mode: str = field( + default="server", + metadata={ + "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `'server'` or " + "`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure " + "a TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same " + "process and share the training GPUs. This avoids the need for a separate server but may cause resource " + "contention with training." + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={ + "help": "Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: " + "Use the `transformers` backend for model implementation. `vllm`: Use the `vllm` library for " + "model implementation." + }, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={ + "help": "Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step " + "and woken for weight sync and generation." + }, + ) + # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + vllm_server_base_url: str | None = field( + default=None, + metadata={ + "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " + "and `vllm_server_port` are ignored." + }, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_port: int = field( + default=8000, + metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={ + "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " + "after the timeout, a `ConnectionError` is raised." + }, + ) + + # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + vllm_gpu_memory_utilization: float = field( + default=0.3, + metadata={ + "help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag." + }, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={ + "help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_tensor_parallel_size` flag." + }, + ) + + # Parameters that control the training + beta: float = field( + default=0.0, + metadata={ + "help": "KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and " + "improving training speed." + }, + ) + alpha: float = field( + default=0.0, + metadata={ + "help": "Alpha coefficient. If `0.0` (default), the forward KL is used. If `1.0`, the reverse KL is used. If anything in between, the Jensen-Shannon Divergence is used." + }, + ) + generate_from_teacher: bool = field( + default=False, + metadata={ + "help": "If True, use the teacher model (ref_model) for generation. vLLM will be initialized with teacher " + "weights, enabling fast generation from the teacher. This makes training equivalent to online SFT " + "where the teacher generates completions and the student learns to reproduce them. " + "If False (default), use the student model for generation (standard RL behavior)." + }, + ) + num_iterations: int = field( + default=1, + metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."}, + ) + mask_truncated_completions: bool = field( + default=False, + metadata={ + "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from " + "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is " + "a good practice for training stability." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + top_entropy_quantile: float = field( + default=1.0, + metadata={ + "help": "ρ parameter from Beyond the 80/20 Rule. Keeps in the policy loss term only the top-ρ quantile of " + "tokens by entropy of the probability distribution at each sequence position, improving results. Range: " + "[0.0-1.0]. A value of `0.0` masks all but the highest entropy token; `1.0` keeps all tokens. The paper " + "recommends a value of `0.2`. If used with `mask_truncated_completions=True`, only tokens from " + "non-truncated completions are considered." + }, + ) + num_loss_tokens_to_skip: int = field( + default=0, + metadata={ + "help": "Number of tokens at the beginning of each completion to exclude from the loss calculation. " + "This can be useful to avoid penalizing the model for the initial tokens of the response, which may be " + "less predictable. A value of `0` (default) means all completion tokens are included in the loss." + }, + ) + vllm_importance_sampling_correction: bool = field( + default=True, + metadata={ + "help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and " + "recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL " + "Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy " + "effects due to subtle implementation differences between generation and training backends. TIS is " + "proposed as a remedy for this issue." + }, + ) + vllm_importance_sampling_cap: float = field( + default=2.0, + metadata={ + "help": "Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the " + "importance sampling ratio, improving training stability." + }, + ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={ + "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " + "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." + }, + ) + num_completions_to_print: int | None = field( + default=None, + metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, + ) + wandb_log_unique_prompts: bool | None = field( + default=False, + metadata={ + "help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, " + "all prompts are logged." + }, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() + + num_processes = self.world_size + # The current default effective batch size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + self.steps_per_generation = self.generation_batch_size // ( + self.per_device_train_batch_size * num_processes + ) + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.do_eval and self.eval_strategy != "no": + # Just ensure the value is divisible by the global batch size + if (self.per_device_eval_batch_size * num_processes) % self.num_generations != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by num_generations ({self.num_generations})." + ) + + # The generation batch must contain full prompt groups (no partials), so it must be divisible by + # num_generations. + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py new file mode 100644 index 00000000000..7a999028cfe --- /dev/null +++ b/trl/experimental/sdft/sdft_trainer.py @@ -0,0 +1,1735 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from collections import defaultdict, deque +from contextlib import nullcontext +from functools import partial +from pathlib import Path +from pprint import pformat +from typing import Any, Optional + +import datasets +import torch +import torch.utils.data +import transformers +from accelerate import logging +from accelerate.state import AcceleratorState +from accelerate.utils import broadcast_object_list, gather_object, is_peft_model, set_seed +from datasets import Dataset, IterableDataset +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.nn.functional import kl_div, log_softmax +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoConfig, + AutoProcessor, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_wandb_available, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available + +from ...data_utils import is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from ...extras.profiling import profiling_context, profiling_decorator +from ...generation.vllm_client import VLLMClient +from ...import_utils import is_vllm_available +from ...models.utils import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + RepeatSampler, + disable_dropout_in_model, + ensure_master_addr_port, + entropy_from_logits, + identity, + nanmax, + nanmin, + pad, + selective_log_softmax, + shuffle_sequence_dict, + split_pixel_values_by_grid, + split_tensor_dict, + unsplit_pixel_values_by_grid, +) +from ..utils import prepare_peft_model +from .sdft_config import SDFTConfig + + +if is_peft_available(): + from peft import PeftConfig, PeftModel + +if is_vllm_available(): + from vllm import LLM, SamplingParams + +if is_wandb_available(): + import wandb + + +logger = logging.get_logger(__name__) + + +class MemoryEfficientSyncRefModelCallback(TrainerCallback): + """ + Memory-efficient callback to synchronize the model with a reference model. + + Unlike the default SyncRefModelCallback, this version iterates through parameters + one at a time instead of gathering all parameters at once. This reduces peak memory + usage from O(full_model_size) to O(single_param_size), making it feasible to sync + large models with DeepSpeed ZeRO-3. + """ + + def __init__( + self, + ref_model: PreTrainedModel | nn.Module, + accelerator: Any | None, + ): + self.accelerator = accelerator + self.ref_model = ref_model + + @staticmethod + def _sync_param(model_param, ref_param, alpha): + """Sync a single parameter: ref = alpha * model + (1 - alpha) * ref""" + ref_param.data.mul_(1.0 - alpha).add_(model_param.data, alpha=alpha) + + @staticmethod + def sync_target_model_memory_efficient(model, target_model, alpha): + """ + Sync target_model to track model, gathering one parameter at a time. + + This is O(1) in memory overhead instead of O(N) where N is model size. + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin + is_zero3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + + if is_zero3: + import deepspeed + + # Iterate through parameters one at a time + for (name, model_param), (_, ref_param) in zip( + model.named_parameters(), target_model.named_parameters(), strict=False + ): + # Gather only this pair of parameters + with deepspeed.zero.GatheredParameters([model_param, ref_param], modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + MemoryEfficientSyncRefModelCallback._sync_param(model_param, ref_param, alpha) + else: + # Non-ZeRO-3: just iterate normally + for model_param, ref_param in zip(model.parameters(), target_model.parameters(), strict=False): + MemoryEfficientSyncRefModelCallback._sync_param(model_param, ref_param, alpha) + + def on_step_end(self, args, state, control, **kwargs): + model: PreTrainedModel = kwargs["model"] + + if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model_memory_efficient(model, self.ref_model, args.ref_model_mixup_alpha) + + +class SDFTTrainer(BaseTrainer): + """ + Trainer for the Self-Distillation method of Language Models. This algorithms is described + in the paper [Self-Distillation for Language Models](https://arxiv.org/pdf/2601.19897) + + Example: + + ```python + from datasets import Dataset + from transformers import AutoModelForCausalLM, AutoTokenizer + from trl.experimental.sdft import SDFTConfig, SDFTTrainer + + student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct") + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + tokenizer.pad_token = tokenizer.eos_token + + dataset = Dataset.from_dict( + { + "prompt": ["Write a haiku about the ocean."], + "teacher_prompt": ["Write a haiku about the ocean."], + } + ) + + training_args = SDFTConfig(output_dir="sdft-model", per_device_train_batch_size=1) + trainer = SDFTTrainer( + model=student_model, + ref_model=teacher_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in + `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + ref_model (`Union[str, PreTrainedModel]`, *optional*): + Teacher model used for distillation. If provided as a string, it is loaded with + [`~transformers.AutoModelForCausalLM.from_pretrained`] using `args.model_init_kwargs`. If `None`, the + trainer will instantiate a teacher model from the same checkpoint as the student. + args ([`SDFTConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include columns `"prompt"` and `"teacher_prompt"`. Additional + columns are ignored unless used for multimodal inputs (e.g., `image` or `images`). The format of the + samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. The padding side must be set to "left". If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A + padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, + `tokenizer.eos_token` will be used as the default. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "sdft"] + _name = "SDFT" + + def __init__( + self, + model: str | PreTrainedModel, + ref_model: str | PreTrainedModel | None = None, + args: SDFTConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: Optional["PeftConfig"] = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = SDFTConfig(f"{model_name}-SDFT") + + # Models + # Trained model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str): # it's a str, but not "auto" + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype + else: + raise ValueError( + "Invalid `dtype` passed to `SDFTConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + # Disable caching if gradient checkpointing is enabled (not supported) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `SDFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it + # Inspect the forward method before we wrap the model with PEFT + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left") + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + # Training arguments + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap + self.mask_truncated_completions = args.mask_truncated_completions + self.top_entropy_quantile = args.top_entropy_quantile + self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + # See https://github.com/huggingface/trl/issues/3213 + raise NotImplementedError( + "Iterable datasets are not yet supported in SDFTTrainer. Please use a standard dataset instead." + ) + self._validate_dataset_columns(train_dataset, "train_dataset") + self._validate_dataset_columns(eval_dataset, "eval_dataset") + + # Multi-step + self.num_iterations = args.num_iterations + # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in GRPO-like algorithms, the sampled data does not include the + # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: + # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To + # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. + # This acts as a flag to indicate that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + super().__init__( + model=model, + args=args, + data_collator=identity, # No data collation is needed in SDFT + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func` + # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the + # global accumulated batch. To control scaling ourselves, we must disable Trainer’s built-in scaling. The + # simplest (though a bit hacky) way is to set `compute_loss_func` to any non-None value, which bypasses + # that behavior without rewriting `training_step`. + compute_loss_func="non-None value to disable scaling", + ) + + # Reference model + self.beta = args.beta + self.alpha = args.alpha + self.generate_from_teacher = args.generate_from_teacher + if isinstance(ref_model, str): + ref_model_id = ref_model + config = AutoConfig.from_pretrained(ref_model_id) + architecture = getattr(transformers, config.architectures[0]) + ref_model = architecture.from_pretrained(ref_model_id, **model_init_kwargs) + elif ref_model is None: + if not model_id: + raise ValueError( + "SDFTTrainer could not infer a teacher checkpoint from the student model. " + "Please pass `ref_model` explicitly." + ) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) + elif not isinstance(ref_model, PreTrainedModel): + raise TypeError("`ref_model` must be a model id or a PreTrainedModel instance.") + + self.ref_model = ref_model + self.ref_model.eval() + for param in self.ref_model.parameters(): + param.requires_grad_(False) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + # Keep logs sized to the generation batch to record only outputs from the latest model update. + self._logs = { + "images": deque(maxlen=args.generation_batch_size), + "prompt": deque(maxlen=args.generation_batch_size), + "completion": deque(maxlen=args.generation_batch_size), + } + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install trl[vllm]` to use it." + ) + + if self.vllm_mode == "server": + if self.accelerator.is_main_process: + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) + self.vllm_client.init_communicator(device=torch.cuda.current_device()) + + elif self.vllm_mode == "colocate": + # Make sure vllm_tensor_parallel_size group size evenly divides the world size - each group should have + # the same number of ranks + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " + f"({self.accelerator.num_processes}) evenly." + ) + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. + # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( + [ + list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ] + ) + + # vLLM requires the environment variables to be set for distributed training. + os.environ["RANK"] = str(self.accelerator.process_index) + os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) + os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() + + if self.max_prompt_length is not None and self.max_completion_length is not None: + max_model_len = self.max_prompt_length + self.max_completion_length + else: + max_model_len = None + # Use teacher model for vLLM when generate_from_teacher=True + if self.generate_from_teacher and self.ref_model is None: + raise ValueError("`generate_from_teacher=True` requires a teacher model.") + vllm_model_path = self.ref_model.name_or_path if self.generate_from_teacher else model.name_or_path + logger.info( + f"[DEBUG] Initializing vLLM with model: {vllm_model_path}, " + f"generate_from_teacher={self.generate_from_teacher}" + ) + self.llm = LLM( + model=vllm_model_path, + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + max_num_seqs=self.args.per_device_train_batch_size + * self.vllm_tensor_parallel_size + * self.args.steps_per_generation, + max_model_len=max_model_len, + distributed_executor_backend="external_launcher", + # Feed identical seed for tp groups to ensure sampling results are the same across workers + seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory + max_num_batched_tokens=4096, + model_impl=self.args.vllm_model_impl, + enable_sleep_mode=self.args.vllm_enable_sleep_mode, + # Important so temperature scaling/logit tweaking affects the TIS log probs + logprobs_mode="processed_logprobs", + ) + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=1) + else: + raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") + + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "repetition_penalty": self.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + self.add_callback( + MemoryEfficientSyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator) + ) + + def _validate_dataset_columns(self, dataset, name: str) -> None: + if dataset is None: + return + if isinstance(dataset, dict): + for key, sub_dataset in dataset.items(): + self._validate_dataset_columns(sub_dataset, f"{name}[{key}]") + return + if not hasattr(dataset, "column_names"): + return + required_columns = {"prompt", "teacher_prompt"} + existing_columns = set(dataset.column_names) + missing = required_columns - existing_columns + if missing: + missing_list = ", ".join(sorted(missing)) + required_list = ", ".join(sorted(required_columns)) + raise ValueError( + f"{name} must include columns [{required_list}]. Missing [{missing_list}]. " + "If you do not have distinct teacher prompts, set `teacher_prompt` to the same value as `prompt`." + ) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In SDFTTrainer, we preprocess data, so using the model's signature columns doesn't work. + # Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt", "teacher_prompt", "image", "images"] + + # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. + # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an + # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions + # once every steps_per_generation step—rather than once per accumulation step—which is significantly more + # efficient. The only change from the original implementation is multiplying the batch size by + # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the + # splitting internally. + # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line + # modification. As a result, some parts of the method aren't relevant to Distil, but we keep them to stay one line + # apart from the super method, ensuring easier maintenance in the future. + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: + # Returns a sampler that + # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are + # distributed to different GPUs, allowing group-wise statistics to be computed consistently. Using the + # same seed across processes ensures consistent prompt assignment, preventing discrepancies in group + # formation. + # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to + # _prepare_inputs to see how the generations are stored and reused. + + # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the + # second row shows the second sampled batch, and so on. + # + # | GPU 0 | GPU 1 | + # + # global_step step <-───> num_generations=2 + # <-───────> per_device_train_batch_size=3 + # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss + # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss + # | + # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss + # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss + # + # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss + # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss + # ... + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + # See _get_train_sampler for an explanation of the sampler. + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations, + seed=self.args.seed, + ) + + def _log_prompt_completions_sample(self) -> None: + if not self.log_completions: + return + num_samples = self.num_completions_to_print or len(self._logs["prompt"]) + for idx, (prompt, completion) in enumerate(zip(self._logs["prompt"], self._logs["completion"], strict=False)): + if idx >= num_samples: + break + prompt_text = pformat(prompt, width=100) + completion_text = pformat(completion, width=100) + logger.info("SDFT sample %s\nPrompt:\n%s\nCompletion:\n%s", idx, prompt_text, completion_text) + + @profiling_decorator + def _get_last_hidden_state( + self, + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + pixel_values=None, + image_grid_thw=None, + pixel_attention_mask=None, + image_sizes=None, + ): + if is_peft_model(unwrapped_model): + unwrapped_model = unwrapped_model.base_model.model + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + + # For Qwen models: + if image_grid_thw is not None and pixel_values is not None: + model_inputs["image_grid_thw"] = image_grid_thw + # For Gemma, SmolVLM2, LLaVa-Next etc.: + if pixel_values is not None: + model_inputs["pixel_values"] = pixel_values + # For SmolVLM2 + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask + # For LLaVa-Next + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state + # Exclude the last value: it corresponds to the next token pred + last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + return last_hidden_state + + def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: + """ + Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold. + + Args: + entropies (`torch.Tensor`): + Tensor of shape (batch_size, seq_len) with per-token entropy values. + mask (`torch.Tensor`): + Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding. + threshold (`float`): + Quantile threshold between `0.0` and `1.0` to select high-entropy tokens. + + Returns: + `torch.Tensor`: + Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold + and `False` otherwise. + """ + local = entropies[mask.bool()].float() + + # Use a negative pad_value as a sentinel because entropy values are always >= 0. + # This guarantees that the sentinel cannot collide with any real entropy value. + pad_value = -1e9 + + # Pad across processes so that every rank has the same tensor length + padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) + gathered = self.accelerator.gather(padded) + + # Drop sentinel values (safe because no entropy can be negative) + gathered = gathered[gathered != pad_value] + + if gathered.numel() == 0: + return torch.zeros_like(entropies, dtype=torch.bool) + + entropy_threshold = torch.quantile(gathered, threshold) + masked_entropies = entropies * mask.float() + entropy_mask = masked_entropies >= entropy_threshold + return entropy_mask & mask.bool() # ensure padding tokens are always masked out + + @profiling_decorator + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + compute_all_logps=True, + ) -> dict[str, torch.Tensor | None]: + """Compute log-probs and (optionally) entropies for each token.""" + batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + all_selected_logps = [] + all_logps = [] + all_entropies = [] + for start in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[start : start + batch_size] + attention_mask_batch = attention_mask[start : start + batch_size] + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + logits = model(**model_inputs).logits + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + + completion_ids = input_ids_batch[:, -logits_to_keep:] + selected_logps = selective_log_softmax(logits, completion_ids) # compute logprobs + if compute_all_logps: + logps = log_softmax(logits, dim=-1) + else: + logps = None + all_selected_logps.append(selected_logps) + all_logps.append(logps) + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + all_entropies.append(entropies) + + selected_logps = torch.cat(all_selected_logps, dim=0) + if compute_all_logps: + logps = torch.cat(all_logps, dim=0) + else: + logps = None + entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + return selected_logps, logps, entropies + + def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None): + extra_prefixes = extra_prefixes or [] + prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes + for prefix in prefixes: + name = name.replace(prefix, "") + return name + + def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): + """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" + # For FSDP1, we need to recurse into children and also use summon_full_params + if visited is None: + visited = set() + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + self._sync_fsdp1_params_to_vllm( + child_module, prefix=child_prefix, visited=visited + ) # recurse into the child + + if isinstance(module, FSDP): + with FSDP.summon_full_params(module, recurse=False, writeback=False): + for param_name, param in module.named_parameters(): + full_name = f"{prefix}.{param_name}" if prefix else param_name + full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) + + if full_name in visited: + continue # skip FSDP subtrees already traversed + visited.add(full_name) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(full_name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(full_name, param.data)]) + + def _sync_fsdp2_params_to_vllm(self, module: nn.Module): + # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion + for name, param in module.state_dict().items(): + if param.is_cpu: + param = param.to(torch.device("cuda")) + param = param.full_tensor() + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param)]) + + @profiling_decorator + def _move_model_to_vllm(self): + # Select which model to sync to vLLM: teacher (ref_model) or student (model) + # When generate_from_teacher=True, sync the teacher model since vLLM was initialized with teacher weights + model_to_sync = self.ref_model if self.generate_from_teacher else self.model + + # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + + if is_peft_model(self.model): + if self.generate_from_teacher: + raise ValueError( + "PEFT model handling only applies when syncing student model (teacher is typically not PEFT)" + ) + # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as + # merging adapters in a sharded manner is not supported. + # TODO: does this work with FSDP? + with gather_if_zero3(list(self.model.parameters())): + self.model.merge_adapter() + + # Update vLLM weights while parameters are gathered + if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext + # Update vLLM weights while parameters are gathered + # For PEFT with FSDP we need to use the memory efficient post-order traversal + fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm( + self.model + ) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(self.model) + else: + # DeepSpeed ZeRO-3 with PEFT + for name, param in self.model.named_parameters(): + # When using PEFT, we need to recover the original parameter name and discard some parameters + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + if self.model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + self.model.unmerge_adapter() + # Parameters will automatically be repartitioned when exiting the context + else: + # For non-PEFT models, simply gather (if needed) and update each parameter individually. + if self.is_fsdp_enabled: + fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm( + model_to_sync + ) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(model_to_sync) + else: + for name, param in model_to_sync.named_parameters(): + name = self._fix_param_name_to_vllm(name) + with gather_if_zero3([param]): + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == "colocate": + self.llm.reset_prefix_cache() + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + # Prepares inputs for model training/evaluation by managing completion generation and batch handling. + # During training: + # - Receives the local generation batch (Per-GPU batch size × steps per generation) + # from the modified training dataloader instead of the standard local batch + # - Generates completions once for the entire generation batch and splits it into batches of size + # `per_device_train_batch_size` + # - Buffers these completions and returns the appropriate slice for the current accumulation step + # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # During evaluation: + # - The input is treated as a standard local batch (no accumulation, no multiple iterations) + # - Completions are generated for each batch without buffering or reuse + # Returns a single local batch in both cases. + + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + # self._buffered_inputs=None can occur when resuming from a checkpoint + generation_batch = self._generate_and_score_completions(generation_batch) + generation_batch = split_pixel_values_by_grid(generation_batch) + generation_batch = shuffle_sequence_dict(generation_batch) + generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) + self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] + inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] + self._step += 1 + else: + # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence + # local generation batch == local eval batch + inputs = self._generate_and_score_completions(generation_batch) + return inputs + + def _generate_single_turn(self, prompts: list[str], images: list | None): + device = self.accelerator.device + + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] + kwargs = {} + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images, strict=False): + if isinstance(prompt, list): # i.e., when using conversational data + prepare_multimodal_messages(prompt, num_images=len(image_list)) + + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + + if images is not None: + prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # Generate completions using either vLLM or regular generation + # Note: When generate_from_teacher=True, vLLM is initialized with teacher weights + if self.use_vllm: + if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: + # wake up colocated vLLM instances if needed + torch.cuda.empty_cache() # required to avoid OOM in some cases + self.llm.wake_up() + + # First, update the vLLM weights if needed + # When generate_from_teacher=True and sync_ref_model=False, teacher is static so no sync needed + # (vLLM already loaded teacher weights at initialization) + should_sync = self.state.global_step != self._last_loaded_step + if self.generate_from_teacher and not self.args.sync_ref_model: + should_sync = False # Teacher is static, no need to sync + if should_sync: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + if self.vllm_mode == "server": + all_prompts_text = gather_object(prompts_text) + if images is not None: + all_images = gather_object(images) + + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + + if images is not None: + ordered_set_of_images = all_images[:: self.num_generations] + else: + ordered_set_of_images = None + + with profiling_context(self, "vLLM.generate"): + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + generation_kwargs=self.args.generation_kwargs, + ) + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) + else: + payload = None + + # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. + obj_list = [payload] + broadcast_object_list(obj_list, from_process=0) + all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] + + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] + + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + prompt_ids = all_prompt_ids[process_slice] + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] + + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts + elif self.vllm_mode == "colocate": + generation_kwargs = { + "n": 1, # vLLM on each GPU generates only 1 in colocate mode + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, + "logprobs": 0, # only return the logprob of the generated token + } + if self.args.generation_kwargs is not None: + generation_kwargs.update(self.args.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts_text) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) + all_prompts_text = [p for sublist in gathered_prompts for p in sublist] + + if images is not None: + gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) + all_images = [img for sublist in gathered_images for img in sublist] + else: + all_images = None + else: + all_prompts_text = prompts_text + all_images = images + + if images is not None and all_images: + vllm_inputs = [] + for prompt, image_list in zip(all_prompts_text, all_images, strict=False): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + + else: + vllm_inputs = all_prompts_text + + with profiling_context(self, "vLLM.generate"): + all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + + all_prompt_ids = [output.prompt_token_ids for output in all_outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_logprobs = [ + [next(iter(lp.values())).logprob for lp in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] + + if self.vllm_tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + prompt_ids = all_prompt_ids[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] + else: + prompt_ids = all_prompt_ids + completion_ids = all_completion_ids + logprobs = all_logprobs + + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=1) + + elif self.use_transformers_paged: + # Re-process inputs for paged generation if needed + # Note: images are already validated and preprocessed above + paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) + previous_attn = self.model_wrapped.config._attn_implementation + + if is_flash_attn_2_available(): + self.model_wrapped.config._attn_implementation = "paged_attention" + else: + self.model_wrapped.config._attn_implementation = "sdpa_paged" + with ( + profiling_context(self, "transformers.generate_batch"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Cast to the appropriate dtype based on training configuration + if self.args.bf16: + unwrapped_model.to(torch.bfloat16) + elif self.args.fp16: + unwrapped_model.to(torch.float16) + with torch.inference_mode(): + all_outputs = unwrapped_model.generate_batch( + paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode + completion_ids = [output.generated_tokens for output in all_outputs.values()] + prompt_ids = paged_prompt_inputs.input_ids + # Restore the original attention implementation, training mode + self.model_wrapped.config._attn_implementation = previous_attn + logprobs = None # not used in this case + + else: + # Regular generation path + generate_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + **kwargs, + ) + generate_inputs = super()._prepare_inputs(generate_inputs) + + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config, disable_compile=True + ) + # Compute prompt length and extract completion ids + prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] + logprobs = None # not used in this case + + return prompt_ids, completion_ids, logprobs, forward_kwargs + + def _generate(self, prompts: list[str], images: list | None): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss + + # Log the metrics + if mode == "train": + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + teacher_prompts = [x["teacher_prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + # Decide whether to generate from teacher (with context) or student (without context) + generation_prompts = teacher_prompts if self.generate_from_teacher else prompts + + ( + _generation_prompt_ids_list, # Discard - we'll compute student/teacher prompt IDs separately + completion_ids_list, + num_items_in_batch, + sampling_per_token_logps_list, + forward_kwargs, + ) = self._generate(generation_prompts, images) + + # Process student prompts (always used for student training, regardless of generation source) + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + if self.use_vllm: + self.processing_class.truncation_side = "left" + student_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + student_inputs = super()._prepare_inputs(student_inputs) + student_prompt_ids, student_prompt_mask = student_inputs["input_ids"], student_inputs["attention_mask"] + prompt_ids_list = [p[m].tolist() for p, m in zip(student_prompt_ids, student_prompt_mask.bool(), strict=False)] + + # Process teacher prompts (always used for teacher, regardless of generation source) + teacher_prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] + for prompt in teacher_prompts + ] + teacher_inputs = self.processing_class( + text=teacher_prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + teacher_inputs = super()._prepare_inputs(teacher_inputs) + if self.use_vllm: + self.processing_class.truncation_side = "right" + teacher_prompt_ids, teacher_prompt_mask = teacher_inputs["input_ids"], teacher_inputs["attention_mask"] + teacher_prompt_ids_list = [ + p[m].tolist() for p, m in zip(teacher_prompt_ids, teacher_prompt_mask.bool(), strict=False) + ] + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + teacher_prompt_ids = [torch.tensor(ids, device=device) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in teacher_prompt_ids] + teacher_prompt_ids = pad(teacher_prompt_ids, padding_value=self.pad_token_id, padding_side="left") + teacher_prompt_mask = pad(teacher_prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + teacher_prompt_completion_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) # (B, P+C) + teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) # (B, P+C) + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + with torch.no_grad(): + # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. + # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. + # Skip when generate_from_teacher=True since importance sampling is not used in that case. + generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if not self.generate_from_teacher and ( + self.args.gradient_accumulation_steps % generate_every != 0 + or (self.use_vllm and self.vllm_importance_sampling_correction) + ): + old_per_token_logps, _, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + compute_all_logps=False, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + # Skip when generate_from_teacher=True since vLLM has teacher weights (no mismatch to correct) + if self.use_vllm and self.vllm_importance_sampling_correction and not self.generate_from_teacher: + importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) + importance_sampling_ratio = torch.clamp( + importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + else: + importance_sampling_ratio = None + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + compute_all_logps=False, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps, _, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + compute_all_logps=False, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text, strict=False): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + # Log prompt and completion texts + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + + if images is not None: + self._logs["images"].extend(gather_object(images)) + + if importance_sampling_ratio is not None: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + delta = delta[completion_mask.bool()] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + + flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "teacher_prompt_ids": teacher_prompt_ids, + "teacher_prompt_mask": teacher_prompt_mask, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if importance_sampling_ratio is not None: + output["importance_sampling_ratio"] = importance_sampling_ratio + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + return output + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDFTTrainer does not support returning outputs") + return self._compute_loss(model, inputs) + + def _compute_loss(self, model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + teacher_prompt_ids, teacher_prompt_mask = inputs["teacher_prompt_ids"], inputs["teacher_prompt_mask"] + + # Create a separate mask for loss computation that skips the first N tokens + # Note: completion_mask is used for both attention (forward pass) and loss computation + # We need to keep the original for attention, but create a modified one for loss + loss_completion_mask = completion_mask + if self.num_loss_tokens_to_skip > 0: + batch_size, seq_len = completion_mask.shape + # Create a mask that is 0 for the first num_loss_tokens_to_skip tokens and 1 elsewhere + token_positions = torch.arange(seq_len, device=completion_mask.device).unsqueeze(0).expand(batch_size, -1) + skip_mask = (token_positions >= self.num_loss_tokens_to_skip).int() + # Apply the skip mask (only mask tokens that were originally unmasked) + loss_completion_mask = completion_mask * skip_mask + + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + teacher_input_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, all_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + ) + + with torch.no_grad(): + teacher_per_token_logps, teacher_all_logps, teacher_entropies = self._get_per_token_logps_and_entropies( + self.ref_model, + teacher_input_ids, + teacher_attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + ) + + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, loss_completion_mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + + # Compute KL divergences using F.kl_div + # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. + if self.alpha == 0: # Forward KL + kl_loss = kl_div(all_logps, teacher_all_logps, reduction="none", log_target=True) + elif self.alpha == 1: # Reverse KL + kl_loss = kl_div(teacher_all_logps, all_logps, reduction="none", log_target=True) + else: + # Compute the log of the mixture distribution + # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture + alpha = torch.tensor(self.alpha, dtype=all_logps.dtype) + mixture_log_probs = torch.logsumexp( + torch.stack([all_logps + torch.log(1 - alpha), teacher_all_logps + torch.log(alpha)]), + dim=0, + ) + + kl_teacher = kl_div(mixture_log_probs, teacher_all_logps, reduction="none", log_target=True) + kl_student = kl_div(mixture_log_probs, all_logps, reduction="none", log_target=True) + + # Compute the Generalized Jensen-Shannon Divergence + kl_loss = alpha * kl_teacher + (1 - alpha) * kl_student + per_token_loss = kl_loss.sum(-1) + + if self.use_vllm and self.vllm_importance_sampling_correction and not self.generate_from_teacher: + ratio = inputs["importance_sampling_ratio"] + importance_weights = (ratio * loss_completion_mask).sum(-1) / loss_completion_mask.sum(-1).clamp(min=1.0) + importance_weights = importance_weights.unsqueeze(-1) + per_token_loss = per_token_loss * importance_weights + + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + + loss = ((per_token_loss * loss_completion_mask).sum(-1) / loss_completion_mask.sum(-1).clamp(min=1.0)).mean() + loss = loss / self.current_gradient_accumulation_steps + + # Log the metrics + mode = "train" if self.model.training else "eval" + + with torch.no_grad(): + kl_approx = ( + (per_token_logps - teacher_per_token_logps) + torch.exp(teacher_per_token_logps - per_token_logps) - 1 + ) + kl_approx_mean = (kl_approx * loss_completion_mask).sum() / loss_completion_mask.sum() + self._metrics[mode]["kl_approx"].append(self.accelerator.gather(kl_approx_mean).nanmean().item()) + + loss_completion_token_count = loss_completion_mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # already reduced to sequence-level + return x.mean() + else: + return (x * loss_completion_mask).sum() / loss_completion_token_count + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + self._metrics[mode]["kl_to_base_model"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + mean_entropy = masked_batch_mean(entropies) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + return loss + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + loss = loss.mean().detach() + return loss, None, None + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + self._log_prompt_completions_sample() + + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + import pandas as pd + + table = { + "step": [str(self.state.global_step)] * len(self._logs["prompt"]), + "prompt": self._logs["prompt"], + "completion": self._logs["completion"], + } + + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) + + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + wandb.log({"completions": wandb.Table(dataframe=df)}) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial)