diff --git a/docs/source/sdft_trainer.md b/docs/source/sdft_trainer.md index c80c521cfdb..ca11b48267a 100644 --- a/docs/source/sdft_trainer.md +++ b/docs/source/sdft_trainer.md @@ -2,9 +2,8 @@ ## Overview -Self-Distillation Fine-Tuning (SDFT) is described in [Self-Distillation for Language Models](https://arxiv.org/pdf/2601.19897). -SDFT trains a student model using a teacher model on the student's generated completions, using a divergence between -student and teacher distributions. +Self-Distillation Fine-Tuning (SDFT) is described in [Self-Distillation Enables Continual Learning](https://huggingface.co/papers/2601.19897) by Idan Shenfeld, Mehul Damani, [Jonas Hübotter](https://huggingface.co/jonhue), Pulkit Agrawal. +SDFT trains a student model using a teacher model on the student's generated completions, using a divergence between student and teacher distributions. The abstract from the paper is the following: @@ -13,26 +12,12 @@ The abstract from the paper is the following: > [!WARNING] > **Experimental:** APIs under `trl.experimental` may change or be removed without notice. -## Usage tips - -- Provide a teacher model via `ref_model`. If you omit it, the trainer will create a teacher from the same checkpoint - as the student. -- Your dataset must contain `prompt` and `teacher_prompt`. If you do not have distinct teacher prompts, set - `teacher_prompt = prompt`. -- Set `generate_from_teacher=True` to generate completions using the teacher model instead of the student. - ## Quick Start ```python from datasets import Dataset -from transformers import AutoModelForCausalLM, AutoTokenizer from trl.experimental.sdft import SDFTConfig, SDFTTrainer -student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") -teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") -tokenizer.pad_token = tokenizer.eos_token - train_dataset = Dataset.from_dict( { "prompt": ["Write a haiku about the ocean."], @@ -42,10 +27,8 @@ train_dataset = Dataset.from_dict( training_args = SDFTConfig(output_dir="sdft-model", per_device_train_batch_size=1) trainer = SDFTTrainer( - model=student_model, - ref_model=teacher_model, + model="Qwen/Qwen2-0.5B-Instruct", args=training_args, - processing_class=tokenizer, train_dataset=train_dataset, ) trainer.train() diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py index 497002be0ef..a95638466dc 100644 --- a/trl/experimental/sdft/sdft_config.py +++ b/trl/experimental/sdft/sdft_config.py @@ -12,22 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2020-2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from dataclasses import dataclass, field +import transformers +from packaging.version import Version from transformers import TrainingArguments @@ -36,8 +24,8 @@ class SDFTConfig(TrainingArguments): r""" Configuration class for the [`SDFTTrainer`]. - This class includes only the parameters that are specific to Self-Distillation Fine-Tuning (SDFT). For a full - list of training arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that + This class includes only the parameters that are specific to Self-Distillation Fine-Tuning (SDFT) training. For a + full list of training arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may differ from those in [`~transformers.TrainingArguments`]. Using [`~transformers.HfArgumentParser`] we can turn this class into @@ -56,11 +44,7 @@ class SDFTConfig(TrainingArguments): > Parameters that control the data preprocessing - remove_unused_columns (`bool`, *optional*, defaults to `False`): - Whether to only keep the columns needed by SDFT (`"prompt"` and `"teacher_prompt"`) in the dataset. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. - num_generations (`int` or `None`, *optional*, defaults to `8`): + num_generations (`int`, *optional*, defaults to `8`): Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value. max_completion_length (`int` or `None`, *optional*, defaults to `256`): @@ -87,12 +71,19 @@ class SDFTConfig(TrainingArguments): top_p (`float`, *optional*, defaults to `1.0`): Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to `1.0` to consider all tokens. - top_k (`int`, *optional*): - Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is + top_k (`int`, *optional*, defaults to `0`): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, top-k-filtering is disabled and all tokens are considered. min_p (`float`, *optional*): Minimum token probability, which will be scaled by the probability of the most likely token. It must be a value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + generation_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or + `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the + generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict + with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. + chat_template_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the `apply_chat_template` function when generating completions. repetition_penalty (`float`, *optional*, defaults to `1.0`): Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat @@ -103,11 +94,6 @@ class SDFTConfig(TrainingArguments): parameter is only effective when `use_vllm` is set to `False`. cache_implementation (`str`, *optional*): Implementation of the cache method for faster generation when `use_vllm` is set to `False`. - generation_kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or - `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the - generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict - with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. > Parameters that control generation acceleration powered by vLLM @@ -126,6 +112,8 @@ class SDFTConfig(TrainingArguments): Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model implementation. + vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) @@ -139,6 +127,9 @@ class SDFTConfig(TrainingArguments): vllm_server_timeout (`float`, *optional*, defaults to `240.0`): Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the timeout, a `ConnectionError` is raised. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port + is occupied, there is no need to change it. > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) @@ -146,13 +137,16 @@ class SDFTConfig(TrainingArguments): Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. + vllm_max_model_length (`int`, *optional*): + Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus + `max_completion_length`; if omitted, it is inferred from the model config. vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when launching the vLLM server via the `--vllm_tensor_parallel_size` flag. vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): - Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and woken - for weight sync and generation. + Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory usage low, but + waking the engine adds host–device transfer latency. > Parameters that control the training @@ -185,25 +179,39 @@ class SDFTConfig(TrainingArguments): `1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with `mask_truncated_completions=True`, only tokens from non-truncated completions are considered. vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): - Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed - logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL - Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework - (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation - and training backends. TIS is proposed as a remedy for this issue. - vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): - Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance - sampling ratio, improving training stability. + Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM completion logprobs and + recomputed training logprobs. If set to `False`, no IS is applied regardless of + `vllm_importance_sampling_mode`. When `True`, the selected mode determines how the IS ratios are computed + and constrained. + vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`): + Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. Possible + values are: + + - `"token_truncate"`: Token-level truncated IS (default). Per-token ratios are clipped from above at C. + - `"token_mask"`: Token-level masked IS. Per-token ratios above C are set to zero. + - `"sequence_truncate"`: Sequence-level truncated IS. A single sequence ratio is clipped from above at + C and applied to all tokens in the sequence. + - `"sequence_mask"`: Sequence-level masked IS. Sequences with ratios above C are masked out. + vllm_importance_sampling_cap (`float`, *optional*, defaults to `3.0`): + Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes, importance + ratios are clipped from above at C. For `*_mask` modes, ratios larger than C are set to zero. > Parameters that control the logging log_completions (`bool`, *optional*, defaults to `False`): Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, - it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. + it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or + `trackio`. num_completions_to_print (`int`, *optional*): Number of completions to print with `rich`. If `None`, all completions are logged. - wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): - Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all prompts - are logged. + log_unique_prompts (`bool`, *optional*, defaults to `False`): + Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all prompts are + logged. + log_completions_hub_repo (`str`, *optional*): + Hugging Face Hub repository to save the completions. Should be a complete repository name like + `'username/reponame'` or `'orgname/reponame'`, or just `'reponame'` in which case the repository will be + created in the currently-logged-in Hugging Face user's namespace. Note that this repository will be public + unless you set `hub_private_repo=True` or your organization's default is to create private repositories." """ _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] @@ -234,6 +242,16 @@ class SDFTConfig(TrainingArguments): "`fp16` is not set." }, ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) # Parameters that control the model and reference model model_init_kwargs: dict | str | None = field( @@ -252,21 +270,6 @@ class SDFTConfig(TrainingArguments): ) # Parameters that control the data preprocessing - # The default value remove_unused_columns is overwritten from the parent class, because SDFT relies on custom - # columns like `teacher_prompt` (and sometimes multimodal inputs). - remove_unused_columns: bool | None = field( - default=False, - metadata={ - "help": "Whether to only keep the columns 'prompt' and 'teacher_prompt' in the dataset. If you use any " - "additional columns (e.g., images), you should keep this to `False`." - }, - ) - max_prompt_length: int | None = field( - default=512, - metadata={ - "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." - }, - ) num_generations: int | None = field( default=8, metadata={ @@ -315,10 +318,10 @@ class SDFTConfig(TrainingArguments): "Set to 1.0 to consider all tokens." }, ) - top_k: int | None = field( - default=None, + top_k: int = field( + default=0, metadata={ - "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, " + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, " "top-k-filtering is disabled and all tokens are considered." }, ) @@ -338,6 +341,13 @@ class SDFTConfig(TrainingArguments): "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them." }, ) + chat_template_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating " + "completions." + }, + ) repetition_penalty: float = field( default=1.0, metadata={ @@ -388,10 +398,15 @@ class SDFTConfig(TrainingArguments): vllm_enable_sleep_mode: bool = field( default=False, metadata={ - "help": "Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step " - "and woken for weight sync and generation." + "help": "Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory " + "usage low, but waking the engine adds host–device transfer latency." }, ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."}, + ) + # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) vllm_server_base_url: str | None = field( default=None, @@ -415,6 +430,13 @@ class SDFTConfig(TrainingArguments): "after the timeout, a `ConnectionError` is raised." }, ) + vllm_group_port: int = field( + default=51216, + metadata={ + "help": "Port number for the weight update group. This is used to communicate with the vLLM server. " + "Unless the port is occupied, there is no need to change it.", + }, + ) # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) vllm_gpu_memory_utilization: float = field( @@ -425,6 +447,13 @@ class SDFTConfig(TrainingArguments): "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag." }, ) + vllm_max_model_length: int | None = field( + default=None, + metadata={ + "help": "Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus " + "`max_completion_length`; if omitted, it is inferred from the model config." + }, + ) vllm_tensor_parallel_size: int = field( default=1, metadata={ @@ -442,21 +471,6 @@ class SDFTConfig(TrainingArguments): "improving training speed." }, ) - alpha: float = field( - default=0.0, - metadata={ - "help": "Alpha coefficient. If `0.0` (default), the forward KL is used. If `1.0`, the reverse KL is used. If anything in between, the Jensen-Shannon Divergence is used." - }, - ) - generate_from_teacher: bool = field( - default=False, - metadata={ - "help": "If True, use the teacher model (ref_model) for generation. vLLM will be initialized with teacher " - "weights, enabling fast generation from the teacher. This makes training equivalent to online SFT " - "where the teacher generates completions and the student learns to reproduce them. " - "If False (default), use the student model for generation (standard RL behavior)." - }, - ) num_iterations: int = field( default=1, metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."}, @@ -501,29 +515,43 @@ class SDFTConfig(TrainingArguments): "non-truncated completions are considered." }, ) - num_loss_tokens_to_skip: int = field( - default=0, + vllm_importance_sampling_correction: bool = field( + default=True, metadata={ - "help": "Number of tokens at the beginning of each completion to exclude from the loss calculation. " - "This can be useful to avoid penalizing the model for the initial tokens of the response, which may be " - "less predictable. A value of `0` (default) means all completion tokens are included in the loss." + "help": "Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM " + "completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied " + "regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how " + "IS ratios are computed and constrained." }, ) - vllm_importance_sampling_correction: bool = field( - default=True, + + vllm_importance_sampling_mode: str = field( + default="sequence_mask", metadata={ - "help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and " - "recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL " - "Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy " - "effects due to subtle implementation differences between generation and training backends. TIS is " - "proposed as a remedy for this issue." + "help": "Specifies how Importance Sampling (IS) is performed when " + "vllm_importance_sampling_correction=True. Modes are defined along two orthogonal " + "dimensions: (1) constraint, which determines how to handle ratios above " + "vllm_importance_sampling_cap (C)—either truncation (clip from above, ρ ← min(ρ, C)) or " + "masking (set ratios above C to zero); and (2) granularity, which determines whether " + "ratios are computed per token or as a single sequence-level ratio applied to all tokens. " + "Supported options are: 'token_truncate', 'token_mask', 'sequence_truncate', and " + "'sequence_mask'." }, ) + vllm_importance_sampling_cap: float = field( - default=2.0, + default=3.0, metadata={ - "help": "Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the " - "importance sampling ratio, improving training stability." + "help": "Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, " + "ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero." + }, + ) + use_bias_correction_kl: bool = field( + default=False, + metadata={ + "help": "Whether to use the unbiased KL divergence estimator with importance sampling correction. This " + "corrects the KL divergence estimate by multiplying it with the importance sampling ratio. " + "This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556)." }, ) @@ -539,19 +567,45 @@ class SDFTConfig(TrainingArguments): default=None, metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, ) - wandb_log_unique_prompts: bool | None = field( + log_unique_prompts: bool = field( default=False, metadata={ - "help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, " - "all prompts are logged." + "help": "Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all " + "prompts are logged." + }, + ) + log_completions_hub_repo: str | None = field( + default=None, + metadata={ + "help": "Hugging Face Hub repository to save the completions. Should be a complete repository name like " + "`'username/reponame'` or `'orgname/reponame'`, or just `'reponame'` in which case the repository will " + "be created in the currently-logged-in Hugging Face user's namespace. Note that this repository will be " + "public unless you set `hub_private_repo=True` or your organization's default is to create private " + "repositories." }, ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {} + self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + super().__post_init__() + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + + if self.log_completions_hub_repo is not None and not self.log_completions: + raise ValueError( + "log_completions_hub_repo is set, but log_completions is False. Enable log_completions to upload " + "completions to the Hub, or unset log_completions_hub_repo." + ) + num_processes = self.world_size # The current default effective batch size if self.generation_batch_size is None and self.steps_per_generation is None: @@ -575,11 +629,14 @@ def __post_init__(self): ) if self.do_eval and self.eval_strategy != "no": + # Determine the number of generations to use for evaluation + num_generations = self.num_generations_eval or self.num_generations + # Just ensure the value is divisible by the global batch size - if (self.per_device_eval_batch_size * num_processes) % self.num_generations != 0: + if (self.per_device_eval_batch_size * num_processes) % num_generations != 0: raise ValueError( f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " - f"divisible by num_generations ({self.num_generations})." + f"divisible by the number of generations used for evaluation ({num_generations})." ) # The generation batch must contain full prompt groups (no partials), so it must be divisible by @@ -589,3 +646,12 @@ def __post_init__(self): f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " f"({self.num_generations})." ) + + if self.num_generations < 2: + raise ValueError( + "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." + ) + + if self.delta is not None and self.use_liger_kernel: + raise ValueError("Liger kernel does not support two-sided GRPO loss yet.") diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 7a999028cfe..6b38496fcf0 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -12,166 +12,112 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2020-2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +import asyncio +import atexit +import copy +import importlib.resources as pkg_resources import inspect import os +import sys +import textwrap +import time +import warnings from collections import defaultdict, deque +from collections.abc import Callable from contextlib import nullcontext from functools import partial from pathlib import Path -from pprint import pformat -from typing import Any, Optional +from typing import Any import datasets +import pandas as pd import torch import torch.utils.data import transformers -from accelerate import logging -from accelerate.state import AcceleratorState -from accelerate.utils import broadcast_object_list, gather_object, is_peft_model, set_seed +from accelerate.logging import get_logger +from accelerate.utils import gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset +from huggingface_hub import CommitScheduler, DatasetCard, DatasetCardData, create_repo +from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.nn.functional import kl_div, log_softmax from torch.utils.data import DataLoader, Sampler from transformers import ( - AutoConfig, + AutoModelForSequenceClassification, AutoProcessor, + AutoTokenizer, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_trackio_available, is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ...data_utils import is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from ...chat_template_utils import add_response_schema, get_training_chat_template, parse_response +from ...data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages +) from ...extras.profiling import profiling_context, profiling_decorator -from ...generation.vllm_client import VLLMClient -from ...import_utils import is_vllm_available -from ...models.utils import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ...generation.vllm_generation import VLLMGeneration +from ...import_utils import is_jmespath_available +from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ...models.utils import _ForwardRedirection, disable_gradient_checkpointing from ...trainer.base_trainer import BaseTrainer +from ...trainer.callbacks import SyncRefModelCallback +from .sdft_config import SDFTConfig from ...trainer.utils import ( RepeatSampler, + create_model_from_path, disable_dropout_in_model, - ensure_master_addr_port, entropy_from_logits, + get_config_model_id, identity, nanmax, nanmin, + nanstd, pad, + print_prompt_completions_sample, selective_log_softmax, shuffle_sequence_dict, + shutdown_event_loop_in_daemon, split_pixel_values_by_grid, split_tensor_dict, + start_event_loop_in_daemon, unsplit_pixel_values_by_grid, + use_adapter, ) -from ..utils import prepare_peft_model -from .sdft_config import SDFTConfig if is_peft_available(): - from peft import PeftConfig, PeftModel + from peft import PeftConfig, PeftModel, get_peft_model -if is_vllm_available(): - from vllm import LLM, SamplingParams if is_wandb_available(): import wandb +if is_trackio_available(): + import trackio -logger = logging.get_logger(__name__) - - -class MemoryEfficientSyncRefModelCallback(TrainerCallback): - """ - Memory-efficient callback to synchronize the model with a reference model. - - Unlike the default SyncRefModelCallback, this version iterates through parameters - one at a time instead of gathering all parameters at once. This reduces peak memory - usage from O(full_model_size) to O(single_param_size), making it feasible to sync - large models with DeepSpeed ZeRO-3. - """ - - def __init__( - self, - ref_model: PreTrainedModel | nn.Module, - accelerator: Any | None, - ): - self.accelerator = accelerator - self.ref_model = ref_model - - @staticmethod - def _sync_param(model_param, ref_param, alpha): - """Sync a single parameter: ref = alpha * model + (1 - alpha) * ref""" - ref_param.data.mul_(1.0 - alpha).add_(model_param.data, alpha=alpha) - - @staticmethod - def sync_target_model_memory_efficient(model, target_model, alpha): - """ - Sync target_model to track model, gathering one parameter at a time. - - This is O(1) in memory overhead instead of O(N) where N is model size. - """ - deepspeed_plugin = AcceleratorState().deepspeed_plugin - is_zero3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - - if is_zero3: - import deepspeed - - # Iterate through parameters one at a time - for (name, model_param), (_, ref_param) in zip( - model.named_parameters(), target_model.named_parameters(), strict=False - ): - # Gather only this pair of parameters - with deepspeed.zero.GatheredParameters([model_param, ref_param], modifier_rank=0): - if deepspeed.comm.get_rank() == 0: - MemoryEfficientSyncRefModelCallback._sync_param(model_param, ref_param, alpha) - else: - # Non-ZeRO-3: just iterate normally - for model_param, ref_param in zip(model.parameters(), target_model.parameters(), strict=False): - MemoryEfficientSyncRefModelCallback._sync_param(model_param, ref_param, alpha) - - def on_step_end(self, args, state, control, **kwargs): - model: PreTrainedModel = kwargs["model"] +logger = get_logger(__name__) - if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: - if self.accelerator: - model = self.accelerator.unwrap_model(model) - self.sync_target_model_memory_efficient(model, self.ref_model, args.ref_model_mixup_alpha) class SDFTTrainer(BaseTrainer): """ - Trainer for the Self-Distillation method of Language Models. This algorithms is described - in the paper [Self-Distillation for Language Models](https://arxiv.org/pdf/2601.19897) + Trainer for the Self-Distillation Fine-Tuning (SDFT) method. This algorithm was initially proposed in the + paper [Self-Distillation Enables Continual Learning](https://huggingface.co/papers/2601.19897). Example: ```python from datasets import Dataset - from transformers import AutoModelForCausalLM, AutoTokenizer - from trl.experimental.sdft import SDFTConfig, SDFTTrainer - - student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct") - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - tokenizer.pad_token = tokenizer.eos_token + from trl.experimental.sdft import SDFTTrainer dataset = Dataset.from_dict( { @@ -180,43 +126,34 @@ class SDFTTrainer(BaseTrainer): } ) - training_args = SDFTConfig(output_dir="sdft-model", per_device_train_batch_size=1) trainer = SDFTTrainer( - model=student_model, - ref_model=teacher_model, - args=training_args, - processing_class=tokenizer, + model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset, ) - trainer.train() ``` Args: - model (`Union[str, PreTrainedModel]`): + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): Model to be trained. Can be either: - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a path to a *directory* containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in - `args.model_init_kwargs`. + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. - ref_model (`Union[str, PreTrainedModel]`, *optional*): - Teacher model used for distillation. If provided as a string, it is loaded with - [`~transformers.AutoModelForCausalLM.from_pretrained`] using `args.model_init_kwargs`. If `None`, the - trainer will instantiate a teacher model from the same checkpoint as the student. + - A [`~peft.PeftModel`] object. Only causal language models are supported. args ([`SDFTConfig`], *optional*): Configuration for this trainer. If `None`, a default configuration is used. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. It must include columns `"prompt"` and `"teacher_prompt"`. Additional - columns are ignored unless used for multimodal inputs (e.g., `image` or `images`). The format of the - samples can be either: + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is + ignored. The format of the samples can be either: - [Standard](dataset_formats#standard): Each sample contains plain text. - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role and content). - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): Processing class used to process the data. The padding side must be set to "left". If `None`, the @@ -229,56 +166,62 @@ class SDFTTrainer(BaseTrainer): If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] method. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): - A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your - model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + tools (list of `Callable`, *optional*): + A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool + should be a standard Python function with properly type-hinted arguments and return values, and a + Google-style docstring describing its purpose, arguments, and return value. For more details, see: + https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, + type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool + use and that it has been fine-tuned for tool calling. """ _tag_names = ["trl", "sdft"] _name = "SDFT" + _paper = { + "title": "Self-Distillation Enables Continual Learning", + "id": "2601.19897", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{shenfeld2026selfdistillation, + title = {{Self-Distillation Enables Continual Learning}}, + author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal}, + year = 2026, + eprint = {arXiv:2601.19897}, + } + """), + } def __init__( self, - model: str | PreTrainedModel, - ref_model: str | PreTrainedModel | None = None, + model: "str | PreTrainedModel | PeftModel", args: SDFTConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), - peft_config: Optional["PeftConfig"] = None, + peft_config: "PeftConfig | None" = None, + tools: list[Callable] | None = None, ): # Args if args is None: - model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model if isinstance(model, str) else get_config_model_id(model.config) model_name = model_name.split("/")[-1] args = SDFTConfig(f"{model_name}-SDFT") - # Models - # Trained model - model_init_kwargs = args.model_init_kwargs or {} + # Model if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str): # it's a str, but not "auto" - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `SDFTConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - # Disable caching if gradient checkpointing is enabled (not supported) - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - model = architecture.from_pretrained(model_id, **model_init_kwargs) + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) else: - model_id = model.config._name_or_path if args.model_init_kwargs is not None: logger.warning( "You passed `model_init_kwargs` to the `SDFTConfig`, but your model is already instantiated. " @@ -293,12 +236,11 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) - # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left") + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -315,10 +257,82 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + + # Tools + if tools: + if not Version(transformers.__version__) >= Version("5.0.0"): + raise ImportError( + "Using tools with SDFTTrainer requires transformers version 5.0.0 or higher. Please use " + "transformers with `pip install --pre transformers` to use this feature." + ) + if not is_jmespath_available(): + raise ImportError( + "Using tools with SDFTTrainer requires the jmespath library for response parsing. Please install " + "it with `pip install jmespath` to use this feature." + ) + self.tools = tools or [] + self._sync_tool_dict = {} + self._async_tool_dict = {} + if self.tools: + for tool in self.tools: + if asyncio.iscoroutinefunction(tool): + self._async_tool_dict[tool.__name__] = tool + else: + self._sync_tool_dict[tool.__name__] = tool + + # Check for async functions to start an event loop on a daemon thread + self._has_async_funcs = any(asyncio.iscoroutinefunction(func) for func in self.tools) + + if self._has_async_funcs: + self.async_loop_thread, self.async_loop, self.async_loop_ready_event = start_event_loop_in_daemon( + name="SDFTTrainer-AsyncLoop" + ) + # wait until the event loop is running in the daemon thread + self.async_loop_ready_event.wait() + atexit.register(shutdown_event_loop_in_daemon, self.async_loop_thread, self.async_loop) + + # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. + # While waiting for broader adoption, we provide this utility function to manually set the response schema for + # known chat templates. + # We need `getattr`` until the base class sets a default None value for response_schema + if tools and not getattr(processing_class, "response_schema", None): + processing_class = add_response_schema(processing_class) + # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template + # isn't, we replace it at initialization with a training-safe, prefix-preserving template. + if tools: + self.chat_template = get_training_chat_template(processing_class) + else: + self.chat_template = None + # Training arguments - self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length - self.num_generations = args.num_generations + self.max_tool_calling_iterations = args.max_tool_calling_iterations or sys.maxsize + self.chat_template_kwargs = args.chat_template_kwargs or {} self.temperature = args.temperature self.top_p = args.top_p self.top_k = args.top_k @@ -329,11 +343,7 @@ def __init__( self.vllm_mode = args.vllm_mode self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode - self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction - self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap self.mask_truncated_completions = args.mask_truncated_completions - self.top_entropy_quantile = args.top_entropy_quantile - self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip # Datasets self.shuffle_dataset = args.shuffle_dataset @@ -349,11 +359,7 @@ def __init__( raise NotImplementedError( "Iterable datasets are not yet supported in SDFTTrainer. Please use a standard dataset instead." ) - self._validate_dataset_columns(train_dataset, "train_dataset") - self._validate_dataset_columns(eval_dataset, "eval_dataset") - # Multi-step - self.num_iterations = args.num_iterations # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle self._step = 0 # Buffer the batch to reuse generated outputs across multiple updates. For more details, see @@ -361,7 +367,7 @@ def __init__( self._buffered_inputs = None # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the - # input tensor associated with the key "input_ids". However, in GRPO-like algorithms, the sampled data does not include the + # input tensor associated with the key "input_ids". However, in SDFT, the sampled data does not include the # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. @@ -385,43 +391,16 @@ def __init__( compute_loss_func="non-None value to disable scaling", ) - # Reference model - self.beta = args.beta - self.alpha = args.alpha - self.generate_from_teacher = args.generate_from_teacher - if isinstance(ref_model, str): - ref_model_id = ref_model - config = AutoConfig.from_pretrained(ref_model_id) - architecture = getattr(transformers, config.architectures[0]) - ref_model = architecture.from_pretrained(ref_model_id, **model_init_kwargs) - elif ref_model is None: - if not model_id: - raise ValueError( - "SDFTTrainer could not infer a teacher checkpoint from the student model. " - "Please pass `ref_model` explicitly." - ) - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) - elif not isinstance(ref_model, PreTrainedModel): - raise TypeError("`ref_model` must be a model id or a PreTrainedModel instance.") - - self.ref_model = ref_model - self.ref_model.eval() - for param in self.ref_model.parameters(): - param.requires_grad_(False) - # Disable dropout in the models if args.disable_dropout: disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} self._total_train_tokens = 0 + self._current_train_step_time = 0.0 self.log_completions = args.log_completions - self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.log_unique_prompts = args.log_unique_prompts self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { @@ -436,88 +415,44 @@ def __init__( set_seed(args.seed, device_specific=True) if self.use_vllm: - if not is_vllm_available(): - raise ImportError( - "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install trl[vllm]` to use it." - ) - - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) - - elif self.vllm_mode == "colocate": - # Make sure vllm_tensor_parallel_size group size evenly divides the world size - each group should have - # the same number of ranks - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. - # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - - # vLLM requires the environment variables to be set for distributed training. - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - # Ensure distributed rendezvous variables are set without colliding across concurrent runs - ensure_master_addr_port() - - if self.max_prompt_length is not None and self.max_completion_length is not None: - max_model_len = self.max_prompt_length + self.max_completion_length - else: - max_model_len = None - # Use teacher model for vLLM when generate_from_teacher=True - if self.generate_from_teacher and self.ref_model is None: - raise ValueError("`generate_from_teacher=True` requires a teacher model.") - vllm_model_path = self.ref_model.name_or_path if self.generate_from_teacher else model.name_or_path - logger.info( - f"[DEBUG] Initializing vLLM with model: {vllm_model_path}, " - f"generate_from_teacher={self.generate_from_teacher}" - ) - self.llm = LLM( - model=vllm_model_path, - tensor_parallel_size=args.vllm_tensor_parallel_size, - gpu_memory_utilization=self.vllm_gpu_memory_utilization, - max_num_seqs=self.args.per_device_train_batch_size - * self.vllm_tensor_parallel_size - * self.args.steps_per_generation, - max_model_len=max_model_len, - distributed_executor_backend="external_launcher", - # Feed identical seed for tp groups to ensure sampling results are the same across workers - seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory - max_num_batched_tokens=4096, - model_impl=self.args.vllm_model_impl, - enable_sleep_mode=self.args.vllm_enable_sleep_mode, - # Important so temperature scaling/logit tweaking affects the TIS log probs - logprobs_mode="processed_logprobs", - ) - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - else: - raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") - + # Initialize vLLM generation backend + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + # vLLM configuration + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + # Server mode configuration + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + # Colocate mode configuration + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size + * args.vllm_tensor_parallel_size + * args.steps_per_generation, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + # Generation configuration + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + max_completion_length=self.max_completion_length, + generation_kwargs=args.generation_kwargs, + # Chat/tool configuration + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + ) self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation - - # When using vLLM, the main process is responsible for loading the model weights. This can cause process - # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we - # synchronize all processes after vLLM has been fully initialized. - self.accelerator.wait_for_everyone() else: generation_kwargs = { "max_new_tokens": self.max_completion_length, @@ -535,6 +470,8 @@ def __init__( if args.generation_kwargs is not None: generation_kwargs.update(args.generation_kwargs) self.generation_config = GenerationConfig(**generation_kwargs) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set @@ -544,46 +481,38 @@ def __init__( # Add tags to the model self.model.add_model_tags(self._tag_names) - if self.ref_model is not None: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - elif self.is_fsdp_enabled: - self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - if args.sync_ref_model: - self.add_callback( - MemoryEfficientSyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator) - ) - - def _validate_dataset_columns(self, dataset, name: str) -> None: - if dataset is None: - return - if isinstance(dataset, dict): - for key, sub_dataset in dataset.items(): - self._validate_dataset_columns(sub_dataset, f"{name}[{key}]") - return - if not hasattr(dataset, "column_names"): - return - required_columns = {"prompt", "teacher_prompt"} - existing_columns = set(dataset.column_names) - missing = required_columns - existing_columns - if missing: - missing_list = ", ".join(sorted(missing)) - required_list = ", ".join(sorted(required_columns)) - raise ValueError( - f"{name} must include columns [{required_list}]. Missing [{missing_list}]. " - "If you do not have distinct teacher prompts, set `teacher_prompt` to the same value as `prompt`." - ) + if self.accelerator.is_main_process and self.log_completions: + os.makedirs(os.path.join(self.args.output_dir, "completions"), exist_ok=True) + if self.args.log_completions_hub_repo is not None: + repo_id = self.args.log_completions_hub_repo + create_repo(repo_id, private=self.args.hub_private_repo, repo_type="dataset", exist_ok=True) + template_path = pkg_resources.files("trl").joinpath("templates/completions_dataset_card.md") + card_data = DatasetCardData( + pretty_name="TRL Completion logs", + tags=["trl", "trl-logs", "completions"], + ) + card = DatasetCard.from_template( + card_data=card_data, + template_path=str(template_path), + repo_id=repo_id, + hub_model_id=self.args.hub_model_id, + ) + card.push_to_hub(repo_id) + self.commit_scheduler = CommitScheduler( + repo_id=repo_id, + repo_type="dataset", + folder_path=f"{self.args.output_dir}/completions", + every=2, # minutes + allow_patterns=["*.parquet"], + ) def _set_signature_columns_if_needed(self): # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs. - # In SDFTTrainer, we preprocess data, so using the model's signature columns doesn't work. - # Instead, we set them to the columns expected by the `training_step` method, hence the override. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). In SDFTTrainer, we preprocess data, so using the model's signature columns doesn't + # work. Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "teacher_prompt", "image", "images"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -593,7 +522,7 @@ def _set_signature_columns_if_needed(self): # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the # splitting internally. # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line - # modification. As a result, some parts of the method aren't relevant to Distil, but we keep them to stay one line + # modification. As a result, some parts of the method aren't relevant to SDFT, but we keep them to stay one line # apart from the super method, ensuring easier maintenance in the future. def get_train_dataloader(self): if self.train_dataset is None: @@ -628,9 +557,9 @@ def get_train_dataloader(self): def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: # Returns a sampler that # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are - # distributed to different GPUs, allowing group-wise statistics to be computed consistently. Using the - # same seed across processes ensures consistent prompt assignment, preventing discrepancies in group - # formation. + # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt + # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies + # in group formation. # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to # _prepare_inputs to see how the generations are stored and reused. @@ -654,9 +583,9 @@ def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: dataset = self.train_dataset return RepeatSampler( data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.num_iterations * self.args.steps_per_generation, + mini_repeat_count=1, + batch_size=self.args.generation_batch_size, + repeat_count=1 * self.args.steps_per_generation, shuffle=self.shuffle_dataset, seed=self.args.seed, ) @@ -665,104 +594,10 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler: # See _get_train_sampler for an explanation of the sampler. return RepeatSampler( data_source=eval_dataset, - mini_repeat_count=self.num_generations, + mini_repeat_count=1, seed=self.args.seed, ) - def _log_prompt_completions_sample(self) -> None: - if not self.log_completions: - return - num_samples = self.num_completions_to_print or len(self._logs["prompt"]) - for idx, (prompt, completion) in enumerate(zip(self._logs["prompt"], self._logs["completion"], strict=False)): - if idx >= num_samples: - break - prompt_text = pformat(prompt, width=100) - completion_text = pformat(completion, width=100) - logger.info("SDFT sample %s\nPrompt:\n%s\nCompletion:\n%s", idx, prompt_text, completion_text) - - @profiling_decorator - def _get_last_hidden_state( - self, - unwrapped_model, - input_ids, - attention_mask, - logits_to_keep, - pixel_values=None, - image_grid_thw=None, - pixel_attention_mask=None, - image_sizes=None, - ): - if is_peft_model(unwrapped_model): - unwrapped_model = unwrapped_model.base_model.model - - # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - - # For Qwen models: - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw - # For Gemma, SmolVLM2, LLaVa-Next etc.: - if pixel_values is not None: - model_inputs["pixel_values"] = pixel_values - # For SmolVLM2 - if pixel_attention_mask is not None: - model_inputs["pixel_attention_mask"] = pixel_attention_mask - # For LLaVa-Next - if image_sizes is not None: - model_inputs["image_sizes"] = image_sizes - - # Only add logits_to_keep if the model supports it - if "logits_to_keep" in self.model_kwarg_keys: - # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - model_inputs["logits_to_keep"] = logits_to_keep + 1 - - model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings - - last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state - # Exclude the last value: it corresponds to the next token pred - last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) - # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. - last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) - return last_hidden_state - - def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: - """ - Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold. - - Args: - entropies (`torch.Tensor`): - Tensor of shape (batch_size, seq_len) with per-token entropy values. - mask (`torch.Tensor`): - Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding. - threshold (`float`): - Quantile threshold between `0.0` and `1.0` to select high-entropy tokens. - - Returns: - `torch.Tensor`: - Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold - and `False` otherwise. - """ - local = entropies[mask.bool()].float() - - # Use a negative pad_value as a sentinel because entropy values are always >= 0. - # This guarantees that the sentinel cannot collide with any real entropy value. - pad_value = -1e9 - - # Pad across processes so that every rank has the same tensor length - padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) - gathered = self.accelerator.gather(padded) - - # Drop sentinel values (safe because no entropy can be negative) - gathered = gathered[gathered != pad_value] - - if gathered.numel() == 0: - return torch.zeros_like(entropies, dtype=torch.bool) - - entropy_threshold = torch.quantile(gathered, threshold) - masked_entropies = entropies * mask.float() - entropy_mask = masked_entropies >= entropy_threshold - return entropy_mask & mask.bool() # ensure padding tokens are always masked out - @profiling_decorator def _get_per_token_logps_and_entropies( self, @@ -778,11 +613,9 @@ def _get_per_token_logps_and_entropies( pixel_attention_mask=None, image_sizes=None, token_type_ids=None, - compute_all_logps=True, ) -> dict[str, torch.Tensor | None]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak - all_selected_logps = [] all_logps = [] all_entropies = [] for start in range(0, input_ids.size(0), batch_size): @@ -825,14 +658,8 @@ def _get_per_token_logps_and_entropies( # Divide logits by sampling temperature. # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details logits = logits / self.temperature - completion_ids = input_ids_batch[:, -logits_to_keep:] - selected_logps = selective_log_softmax(logits, completion_ids) # compute logprobs - if compute_all_logps: - logps = log_softmax(logits, dim=-1) - else: - logps = None - all_selected_logps.append(selected_logps) + logps = selective_log_softmax(logits, completion_ids) # compute logprobs all_logps.append(logps) if compute_entropy: @@ -840,146 +667,20 @@ def _get_per_token_logps_and_entropies( entropies = entropy_from_logits(logits) all_entropies.append(entropies) - selected_logps = torch.cat(all_selected_logps, dim=0) - if compute_all_logps: - logps = torch.cat(all_logps, dim=0) - else: - logps = None + logps = torch.cat(all_logps, dim=0) entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None - return selected_logps, logps, entropies - - def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None): - extra_prefixes = extra_prefixes or [] - prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes - for prefix in prefixes: - name = name.replace(prefix, "") - return name - - def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" - # For FSDP1, we need to recurse into children and also use summon_full_params - if visited is None: - visited = set() - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - self._sync_fsdp1_params_to_vllm( - child_module, prefix=child_prefix, visited=visited - ) # recurse into the child - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(full_name, param.data)]) - - def _sync_fsdp2_params_to_vllm(self, module: nn.Module): - # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion - for name, param in module.state_dict().items(): - if param.is_cpu: - param = param.to(torch.device("cuda")) - param = param.full_tensor() - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param)]) - - @profiling_decorator - def _move_model_to_vllm(self): - # Select which model to sync to vLLM: teacher (ref_model) or student (model) - # When generate_from_teacher=True, sync the teacher model since vLLM was initialized with teacher weights - model_to_sync = self.ref_model if self.generate_from_teacher else self.model - - # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - - if is_peft_model(self.model): - if self.generate_from_teacher: - raise ValueError( - "PEFT model handling only applies when syncing student model (teacher is typically not PEFT)" - ) - # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as - # merging adapters in a sharded manner is not supported. - # TODO: does this work with FSDP? - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext - # Update vLLM weights while parameters are gathered - # For PEFT with FSDP we need to use the memory efficient post-order traversal - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm( - self.model - ) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - # DeepSpeed ZeRO-3 with PEFT - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather (if needed) and update each parameter individually. - if self.is_fsdp_enabled: - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm( - model_to_sync - ) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(model_to_sync) - else: - for name, param in model_to_sync.named_parameters(): - name = self._fix_param_name_to_vllm(name) - with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - - # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - elif self.vllm_mode == "colocate": - self.llm.reset_prefix_cache() + return logps, entropies + + def training_step(self, model, inputs, num_items_in_batch): + time_before = time.perf_counter() + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + time_after = time.perf_counter() + self._current_train_step_time += time_after - time_before + if self._step % self.current_gradient_accumulation_steps == 0: + self._metrics["train"]["step_time"].append(self._current_train_step_time) + self._current_train_step_time = 0.0 + return output @profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: @@ -990,7 +691,7 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di # - Generates completions once for the entire generation batch and splits it into batches of size # `per_device_train_batch_size` # - Buffers these completions and returns the appropriate slice for the current accumulation step - # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # - Optimizes by regenerating completions only periodically (every steps_per_generation) # During evaluation: # - The input is treated as a standard local batch (no accumulation, no multiple iterations) # - Completions are generated for each batch without buffering or reuse @@ -998,195 +699,52 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di mode = "train" if self.model.training else "eval" if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations + generate_every = self.args.steps_per_generation if self._step % generate_every == 0 or self._buffered_inputs is None: # self._buffered_inputs=None can occur when resuming from a checkpoint - generation_batch = self._generate_and_score_completions(generation_batch) + generation_batch = self._generate_completions(generation_batch) generation_batch = split_pixel_values_by_grid(generation_batch) generation_batch = shuffle_sequence_dict(generation_batch) generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] - self._step += 1 else: # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence # local generation batch == local eval batch - inputs = self._generate_and_score_completions(generation_batch) + inputs = self._generate_completions(generation_batch) return inputs - def _generate_single_turn(self, prompts: list[str], images: list | None): + def _generate_single_turn(self, prompts: list): device = self.accelerator.device - - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images, strict=False): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} + mode = "train" if self.model.training else "eval" # Generate completions using either vLLM or regular generation - # Note: When generate_from_teacher=True, vLLM is initialized with teacher weights if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() - - # First, update the vLLM weights if needed - # When generate_from_teacher=True and sync_ref_model=False, teacher is static so no sync needed - # (vLLM already loaded teacher weights at initialization) - should_sync = self.state.global_step != self._last_loaded_step - if self.generate_from_teacher and not self.args.sync_ref_model: - should_sync = False # Teacher is static, no need to sync - if should_sync: - self._move_model_to_vllm() + # Sync weights if training step changed + if self.state.global_step != self._last_loaded_step: + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - generation_kwargs=self.args.generation_kwargs, - ) - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) - else: - payload = None - - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] - - # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times - all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] - - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - prompt_ids = all_prompt_ids[process_slice] - completion_ids = all_completion_ids[process_slice] - logprobs = all_logprobs[process_slice] - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts - elif self.vllm_mode == "colocate": - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "truncate_prompt_tokens": self.max_prompt_length, - "logprobs": 0, # only return the logprob of the generated token - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images, strict=False): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text - - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - - all_prompt_ids = [output.prompt_token_ids for output in all_outputs] - all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - prompt_ids = all_prompt_ids[tp_slice] - completion_ids = all_completion_ids[tp_slice] - logprobs = all_logprobs[tp_slice] - else: - prompt_ids = all_prompt_ids - completion_ids = all_completion_ids - logprobs = all_logprobs - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) + # Generate using vLLM + prompt_ids, completion_ids, logprobs, _ = self.vllm_generation.generate( + prompts=prompts, num_generations=1, profiler=profiling_context(self, "vLLM.generate") + ) elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" + if is_conversational({"prompt": prompts[0]}): + processor_outputs = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **self.chat_template_kwargs, + ) else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" + processor_outputs = self.processing_class(text=prompts) + with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -1201,34 +759,43 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): elif self.args.fp16: unwrapped_model.to(torch.float16) with torch.inference_mode(): + # Continuous batching API expects 'inputs' arg only all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + processor_outputs["input_ids"], generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn + prompt_ids = processor_outputs["input_ids"] logprobs = None # not used in this case else: # Regular generation path - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - **kwargs, - ) + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + padding=True, + padding_side="left", + return_tensors="pt", + return_dict=True, + **self.chat_template_kwargs, + ) + else: + generate_inputs = self.processing_class( + text=prompts, padding=True, padding_side="left", return_tensors="pt" + ) generate_inputs = super()._prepare_inputs(generate_inputs) with ( profiling_context(self, "transformers.generate"), unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 ) as unwrapped_model, torch.no_grad(), FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), @@ -1247,21 +814,221 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] - completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] logprobs = None # not used in this case - return prompt_ids, completion_ids, logprobs, forward_kwargs + return prompt_ids, completion_ids, logprobs + + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs): + # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt + tool_calls = [completion[0].get("tool_calls") for completion in completions] + idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] + tool_calls = [tool_calls[idx] for idx in idxs_with_tool] + tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere + tool_call_count = 0 + tool_failure_count = 0 + iteration_num = 0 + while idxs_with_tool and iteration_num < self.max_tool_calling_iterations: + prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls + + # Call the tools, and build the new prompt for generation + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + tool_call_list = tool_calls[idx] + prompt_completion_tool = prompt_completion_tools[idx] + # Append the last assistant message (which triggered tool_calls) to the prompt + prompt_completion_tool.append(completions[idx_with_tool][-1]) + async_coros = [] + tool_call_results = [] + for tool_call in tool_call_list: + tool_call_count += 1 + if tool_call["type"] == "function": + function = tool_call["function"] + name = function["name"] + try: + if name in self._sync_tool_dict: + tool_call_results.append((name, self._sync_tool_dict[name](**function["arguments"]))) + elif name in self._async_tool_dict: + async_coros.append((name, self._async_tool_dict[name](**function["arguments"]))) + except Exception as e: + tool_failure_count += 1 + result = {"error": str(e)} + tool_call_results.append((name, result)) + else: + tool_failure_count += 1 + name = tool_call.get("name", "unknown") + tool_call_results.append((name, {"error": f"Unsupported tool call type: {tool_call['type']}"})) + + if async_coros: + + async def _run_async_tools(async_coros): + coros = [coro for _, coro in async_coros] + results = await asyncio.gather(*coros, return_exceptions=True) + return [(name, result) for (name, _), result in zip(async_coros, results, strict=False)] + + async_results = asyncio.run_coroutine_threadsafe( + _run_async_tools(async_coros), self.async_loop + ).result() + + for name, result in async_results: + if isinstance(result, Exception): + tool_failure_count += 1 + tool_call_results.append((name, {"error": str(result)})) + else: + tool_call_results.append((name, result)) + + for name, result in tool_call_results: + tool_message = {"role": "tool", "name": name, "content": str(result)} + prompt_completion_tool.append(tool_message) + completions[idx_with_tool].append(tool_message) + + # Tokenize and filter samples whose length exceeds max allowed length. This is important, because both + # vLLM and transformers will error out if the input is longer than the model's max length. + pct_ids = self.processing_class.apply_chat_template( + prompt_completion_tools, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=False, + **self.chat_template_kwargs, + ) + if self.use_vllm and self.vllm_mode == "colocate": + max_model_len = self.llm.llm_engine.model_config.max_model_len + elif not self.use_vllm: + max_model_len = self.model.config.max_position_embeddings + else: + raise NotImplementedError( + f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" + ) + overlong = [len(pct) >= max_model_len for pct in pct_ids] + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + if overlong[idx]: + prompt_length = len(prompt_ids[idx_with_tool]) + ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length] + completion_ids[idx_with_tool] = ct + tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) + # Keep only non-overlong items for further processing + idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] + prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] + if not idxs_with_tool: + break # all overlong, exit tool loop + + # Generate new completions after tool execution + prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn( + prompt_completion_tools + ) + + # Sanity check: from experience, this is useful to catch bugs in the chat template + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + if prompt_ids[idx_with_tool] != pct[: len(prompt_ids[idx_with_tool])]: + raise ValueError( + "The chat template is not prefix-preserving. Please update it to use a prefix-preserving " + "format." + ) - def _generate(self, prompts: list[str], images: list | None): + # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_len = len(prompt_ids[idx_with_tool]) + completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:] + excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + # If exceeding max length, truncate post_tool_ids + post_tool_ids[idx] = post_tool_ids[idx][:-excess_length] + if logprobs is not None: + post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length] + excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + # If still exceeding max length, truncate completion_tool_ids as well + prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] + + # Update tool_mask: the tool result should be 0 and the post-tool 1 + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) + prompt_length = len(prompt_ids[idx_with_tool]) + completion_length = len(completion_ids[idx_with_tool]) + post_tool_length = len(post_tool_ids[idx]) + tool_length = prompt_completion_tool_length - prompt_length - completion_length + tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] + + # Update completion_ids with the new completions (after tool execution) + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_length = len(prompt_ids[idx_with_tool]) + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] + + # Decode post-tool completions + post_tool_completions = [ + parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids + ] + + # Add post-tool completions to the existing completions + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + if post_tool_completions[idx]: # {} if post-tool completions completely truncated + completions[idx_with_tool].append(post_tool_completions[idx]) + + # Check for further tool calls + tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] + idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] + tool_calls = [tool_call for tool_call in tool_calls if tool_call] + iteration_num += 1 + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count + + def _generate(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + # Copy the prompts to avoid modifying the original list + prompts = copy.deepcopy(prompts) + + prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts) + + # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + if is_conversational({"prompt": prompts[0]}): + if ( + Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 + and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors + and hasattr(self.processing_class, "response_schema") # attribute not set by default for now + and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + ): + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + else: + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in contents] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Extract tool calls from the completions and (possibly) execute them + if self.tools: + ( + tool_mask, + completions, + completion_ids, + logprobs, + tool_call_count, + tool_failure_count, + ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs) + else: + tool_mask = None # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) - completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + if tool_mask is not None: # count only non-tool tokens (tool_mask=1) + completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device) + else: + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) agg_prompt_lengths = self.accelerator.gather(prompt_lengths) agg_completion_lengths = self.accelerator.gather(completion_lengths) total_prompt_tokens = agg_prompt_lengths.sum() @@ -1289,16 +1056,32 @@ def _generate(self, prompts: list[str], images: list | None): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs + if self.tools: + agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum() + tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item() + self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency) + agg_tool_failure_count = self.accelerator.gather(torch.tensor(tool_failure_count, device=device)).sum() + failure_frequency = ( + (agg_tool_failure_count / agg_tool_call_count).item() if agg_tool_call_count > 0 else 0.0 + ) + self._metrics[mode]["tools/failure_frequency"].append(failure_frequency) + + return ( + prompt_ids, + completion_ids, + tool_mask, + completions, + total_completion_tokens, + logprobs + ) - def _generate_and_score_completions( + def _generate_completions( self, inputs: list[dict[str, torch.Tensor | Any]] ) -> dict[str, torch.Tensor | Any]: device = self.accelerator.device mode = "train" if self.model.training else "eval" prompts = [x["prompt"] for x in inputs] - teacher_prompts = [x["teacher_prompt"] for x in inputs] if "images" in inputs[0]: images = [example.get("images") for example in inputs] @@ -1310,67 +1093,29 @@ def _generate_and_score_completions( if images is not None and all(img_list == [] for img_list in images): images = None - # Decide whether to generate from teacher (with context) or student (without context) - generation_prompts = teacher_prompts if self.generate_from_teacher else prompts + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [ + prepare_multimodal_messages(prompt, image_list) + for prompt, image_list in zip(prompts, images, strict=True) + ] ( - _generation_prompt_ids_list, # Discard - we'll compute student/teacher prompt IDs separately + prompt_ids_list, completion_ids_list, + tool_mask_list, + completions, num_items_in_batch, sampling_per_token_logps_list, - forward_kwargs, - ) = self._generate(generation_prompts, images) - - # Process student prompts (always used for student training, regardless of generation source) - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - if self.use_vllm: - self.processing_class.truncation_side = "left" - student_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - student_inputs = super()._prepare_inputs(student_inputs) - student_prompt_ids, student_prompt_mask = student_inputs["input_ids"], student_inputs["attention_mask"] - prompt_ids_list = [p[m].tolist() for p, m in zip(student_prompt_ids, student_prompt_mask.bool(), strict=False)] - - # Process teacher prompts (always used for teacher, regardless of generation source) - teacher_prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] - for prompt in teacher_prompts - ] - teacher_inputs = self.processing_class( - text=teacher_prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - ) - teacher_inputs = super()._prepare_inputs(teacher_inputs) - if self.use_vllm: - self.processing_class.truncation_side = "right" - teacher_prompt_ids, teacher_prompt_mask = teacher_inputs["input_ids"], teacher_inputs["attention_mask"] - teacher_prompt_ids_list = [ - p[m].tolist() for p, m in zip(teacher_prompt_ids, teacher_prompt_mask.bool(), strict=False) - ] + ) = self._generate(prompts) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - teacher_prompt_ids = [torch.tensor(ids, device=device) for ids in teacher_prompt_ids_list] - teacher_prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in teacher_prompt_ids] - teacher_prompt_ids = pad(teacher_prompt_ids, padding_value=self.pad_token_id, padding_side="left") - teacher_prompt_mask = pad(teacher_prompt_mask, padding_value=0, padding_side="left") completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") @@ -1380,6 +1125,9 @@ def _generate_and_score_completions( sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") else: sampling_per_token_logps = None + if self.tools: + tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -1390,95 +1138,36 @@ def _generate_and_score_completions( # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - teacher_prompt_completion_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) # (B, P+C) - teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) # (B, P+C) - # If token_type_ids are used, extend them with zeros for the completion part - if "token_type_ids" in forward_kwargs: - token_type_ids = forward_kwargs["token_type_ids"] - forward_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 - ) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size num_images = [len(img_list) for img_list in images] if images is not None else None - with torch.no_grad(): - # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of - # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the - # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps - # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set - # old_per_token_logps to None. - # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the - # distribution mismatch between vLLM and the training model can be large and harm the training. - # Skip when generate_from_teacher=True since importance sampling is not used in that case. - generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency - if not self.generate_from_teacher and ( - self.args.gradient_accumulation_steps % generate_every != 0 - or (self.use_vllm and self.vllm_importance_sampling_correction) - ): - old_per_token_logps, _, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size, - num_images=num_images, - compute_all_logps=False, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - old_per_token_logps = None - - # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch - # Skip when generate_from_teacher=True since vLLM has teacher weights (no mismatch to correct) - if self.use_vllm and self.vllm_importance_sampling_correction and not self.generate_from_teacher: - importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=self.vllm_importance_sampling_cap - ) - else: - importance_sampling_ratio = None - - # Compute the per-token log probabilities for the reference model - if self.beta != 0.0: - if self.ref_model is not None: - ref_per_token_logps, _, _ = self._get_per_token_logps_and_entropies( - self.ref_model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - compute_all_logps=False, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - ref_per_token_logps, _, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - compute_all_logps=False, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - ref_per_token_logps = None + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template( + {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + )["prompt"] + for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) # Decode prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text, strict=False): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text # Log prompt and completion texts self._logs["prompt"].extend(gather_object(prompts_text)) @@ -1487,53 +1176,15 @@ def _generate_and_score_completions( if images is not None: self._logs["images"].extend(gather_object(images)) - if importance_sampling_ratio is not None: - delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - delta = delta[completion_mask.bool()] - mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( - self.accelerator.gather(mean_delta).mean().item() - ) - self._metrics[mode]["sampling/sampling_logp_difference/max"].append( - self.accelerator.gather(max_delta).max().item() - ) - - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] - min_importance_sampling_ratio = ( - torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - mean_importance_sampling_ratio = ( - torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - max_importance_sampling_ratio = ( - torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( - nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( - self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( - nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() - ) - output = { "prompt_ids": prompt_ids, "prompt_mask": prompt_mask, "completion_ids": completion_ids, "completion_mask": completion_mask, - "teacher_prompt_ids": teacher_prompt_ids, - "teacher_prompt_mask": teacher_prompt_mask, "num_items_in_batch": num_items_in_batch, } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - if importance_sampling_ratio is not None: - output["importance_sampling_ratio"] = importance_sampling_ratio - if ref_per_token_logps is not None: - output["ref_per_token_logps"] = ref_per_token_logps + if sampling_per_token_logps is not None: + output["sampling_per_token_logps"] = sampling_per_token_logps if "pixel_values" in forward_kwargs: output["pixel_values"] = forward_kwargs["pixel_values"] if "image_grid_thw" in forward_kwargs: @@ -1546,6 +1197,8 @@ def _generate_and_score_completions( output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: output["num_images"] = num_images + if self.tools: + output["tool_mask"] = tool_mask return output @profiling_decorator @@ -1558,28 +1211,13 @@ def _compute_loss(self, model, inputs): # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - teacher_prompt_ids, teacher_prompt_mask = inputs["teacher_prompt_ids"], inputs["teacher_prompt_mask"] - - # Create a separate mask for loss computation that skips the first N tokens - # Note: completion_mask is used for both attention (forward pass) and loss computation - # We need to keep the original for attention, but create a modified one for loss - loss_completion_mask = completion_mask - if self.num_loss_tokens_to_skip > 0: - batch_size, seq_len = completion_mask.shape - # Create a mask that is 0 for the first num_loss_tokens_to_skip tokens and 1 elsewhere - token_positions = torch.arange(seq_len, device=completion_mask.device).unsqueeze(0).expand(batch_size, -1) - skip_mask = (token_positions >= self.num_loss_tokens_to_skip).int() - # Apply the skip mask (only mask tokens that were originally unmasked) - loss_completion_mask = completion_mask * skip_mask - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - teacher_input_ids = torch.cat([teacher_prompt_ids, completion_ids], dim=1) - teacher_attention_mask = torch.cat([teacher_prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + mask = completion_mask if not self.tools else completion_mask * inputs["tool_mask"] # Compute the per_token_logps and the entropy at each position in the completion - per_token_logps, all_logps, entropies = self._get_per_token_logps_and_entropies( + per_token_logps, entropies = self._get_per_token_logps_and_entropies( model, input_ids, attention_mask, @@ -1593,94 +1231,37 @@ def _compute_loss(self, model, inputs): token_type_ids=inputs.get("token_type_ids"), ) - with torch.no_grad(): - teacher_per_token_logps, teacher_all_logps, teacher_entropies = self._get_per_token_logps_and_entropies( - self.ref_model, - teacher_input_ids, - teacher_attention_mask, - logits_to_keep, - compute_entropy=True, - pixel_values=inputs.get("pixel_values"), - image_grid_thw=inputs.get("image_grid_thw"), - num_images=inputs.get("num_images"), - pixel_attention_mask=inputs.get("pixel_attention_mask"), - image_sizes=inputs.get("image_sizes"), - token_type_ids=inputs.get("token_type_ids"), - ) - - if self.top_entropy_quantile < 1.0: - entropy_mask = self.get_high_entropy_mask(entropies, loss_completion_mask, 1 - self.top_entropy_quantile) - else: - entropy_mask = None - - # Compute the KL divergence between the model and the reference model - if self.beta != 0.0: - ref_per_token_logps = inputs["ref_per_token_logps"] - per_token_kl = ( - torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - ) - - # Compute KL divergences using F.kl_div - # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. - if self.alpha == 0: # Forward KL - kl_loss = kl_div(all_logps, teacher_all_logps, reduction="none", log_target=True) - elif self.alpha == 1: # Reverse KL - kl_loss = kl_div(teacher_all_logps, all_logps, reduction="none", log_target=True) - else: - # Compute the log of the mixture distribution - # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture - alpha = torch.tensor(self.alpha, dtype=all_logps.dtype) - mixture_log_probs = torch.logsumexp( - torch.stack([all_logps + torch.log(1 - alpha), teacher_all_logps + torch.log(alpha)]), - dim=0, - ) - - kl_teacher = kl_div(mixture_log_probs, teacher_all_logps, reduction="none", log_target=True) - kl_student = kl_div(mixture_log_probs, all_logps, reduction="none", log_target=True) - - # Compute the Generalized Jensen-Shannon Divergence - kl_loss = alpha * kl_teacher + (1 - alpha) * kl_student - per_token_loss = kl_loss.sum(-1) - - if self.use_vllm and self.vllm_importance_sampling_correction and not self.generate_from_teacher: - ratio = inputs["importance_sampling_ratio"] - importance_weights = (ratio * loss_completion_mask).sum(-1) / loss_completion_mask.sum(-1).clamp(min=1.0) - importance_weights = importance_weights.unsqueeze(-1) - per_token_loss = per_token_loss * importance_weights - - if entropy_mask is not None: - per_token_loss = per_token_loss * entropy_mask + # Compute the loss + log_ratio = per_token_logps - per_token_logps.detach() + log_importance_weights = log_ratio + coef_1 = torch.exp(log_importance_weights) + coef_2 = coef_1 + per_token_loss1 = coef_1 + per_token_loss2 = coef_2 + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) - loss = ((per_token_loss * loss_completion_mask).sum(-1) / loss_completion_mask.sum(-1).clamp(min=1.0)).mean() - loss = loss / self.current_gradient_accumulation_steps - - # Log the metrics mode = "train" if self.model.training else "eval" - with torch.no_grad(): - kl_approx = ( - (per_token_logps - teacher_per_token_logps) + torch.exp(teacher_per_token_logps - per_token_logps) - 1 - ) - kl_approx_mean = (kl_approx * loss_completion_mask).sum() / loss_completion_mask.sum() - self._metrics[mode]["kl_approx"].append(self.accelerator.gather(kl_approx_mean).nanmean().item()) + loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer - loss_completion_token_count = loss_completion_mask.sum().clamp(min=1.0) + # Log the metrics + completion_token_count = mask.sum().clamp(min=1.0) def masked_batch_mean(x): - if x.shape[1] == 1: # already reduced to sequence-level + if x.shape[1] == 1: # when importance_sampling_level == "sequence" return x.mean() else: - return (x * loss_completion_mask).sum() / loss_completion_token_count - - if self.beta != 0.0: - mean_kl = masked_batch_mean(per_token_kl) - self._metrics[mode]["kl_to_base_model"].append(self.accelerator.gather(mean_kl).nanmean().item()) + return (x * mask).sum() / completion_token_count mean_entropy = masked_batch_mean(entropies) self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) return loss + # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and + # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels. def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): inputs = self._prepare_inputs(inputs) with torch.no_grad(): @@ -1703,27 +1284,54 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() if self.accelerator.is_main_process and self.log_completions: - self._log_prompt_completions_sample() + if is_rich_available(): + print_prompt_completions_sample( + self._logs["prompt"], + self._logs["completion"], + self.state.global_step, + self.num_completions_to_print, + ) + logging_backends = [] if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: - import pandas as pd + logging_backends.append(wandb) + if self.args.report_to and "trackio" in self.args.report_to: + logging_backends.append(trackio) + + table = { + "step": [self.state.global_step] * len(self._logs["prompt"]), + "prompt": self._logs["prompt"], + "completion": self._logs["completion"], + } + + df_base = pd.DataFrame(table) + df_base.to_parquet( + os.path.join( + self.args.output_dir, + "completions", + f"completions_{self.state.global_step:05d}.parquet", + ) + ) - table = { - "step": [str(self.state.global_step)] * len(self._logs["prompt"]), - "prompt": self._logs["prompt"], - "completion": self._logs["completion"], - } + images_raw = self._logs["images"] or [] - if self._logs["images"]: - table["images"] = [] + for logging_backend in logging_backends: + if images_raw: + images = [] for image_list in self._logs["images"]: - # Convert images to wandb Image objects for proper visualization - table["images"].append([wandb.Image(image) for image in image_list]) + images.append([logging_backend.Image(image) for image in image_list]) + df = pd.concat( + [df_base, pd.Series(images, name="image")], + axis=1, + copy=False, + ) + else: + df = df_base - df = pd.DataFrame(table) - if self.wandb_log_unique_prompts: + if self.log_unique_prompts: df = df.drop_duplicates(subset=["prompt"]) - wandb.log({"completions": wandb.Table(dataframe=df)}) + + logging_backend.log({"completions": logging_backend.Table(dataframe=df)}) # Ensure the model card is saved along with the checkpoint def _save_checkpoint(self, model, trial):