diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 4a0f1d0d70d..ccf2f4d528d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -127,6 +127,10 @@ title: PPO - local: prm_trainer title: PRM + - local: sdft_trainer + title: SDFT + - local: sdpo_trainer + title: SDPO - local: winrate_callback title: WinRateCallback - local: xpo_trainer diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 858393120d6..e3f630343e1 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1630,6 +1630,84 @@ trainer.train() For more details, see the [MiniLLM Trainer documentation](minillm) documentation. +### Reinforcement Learning via Self-Distillation + +**📜 Paper**: https://huggingface.co/papers/2601.20802 + +Self-Distillation Policy Optimization (SDPO) enhances reinforcement learning with verifiable rewards by converting rich textual feedback (e.g., runtime errors, judge evaluations) into a dense learning signal without any external teacher or explicit reward model. SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed next-token predictions back into the policy. Notably, SDPO also outperforms baselines in standard RLVR environments that only return scalar feedback by using successful rollouts as implicit feedback for failed attempts. + +```python +from trl.experimental.sdpo import SDPOConfig, SDPOTrainer + +training_args = SDPOConfig( + distillation_alpha=0.5, # Jensen-Shannon divergence (recommended) + distillation_topk=100, # Top-K logit distillation approximation + full_logit_distillation=True, # Required for top-K logit-level SDPO + distillation_is_clip=2.0, # Importance sampling clipping + distillation_weight=1.0, # Weight for self-distillation loss + sdpo_policy_loss_mode="distillation_only", + use_successful_as_teacher=True, # Use successful rollouts as teacher + teacher_regularization="ema", # Supported: "ema", "none" + teacher_update_rate=0.05, # EMA update rate + include_environment_feedback=False, # Use dataset privileged_context when available +) + +trainer = SDPOTrainer( + model="Qwen/Qwen2.5-1.5B-Instruct", + reward_funcs=..., + args=training_args, + train_dataset=..., +) +trainer.train() +``` + +Expected dataset columns: + +- `prompt` +- `privileged_context` for optional environment feedback + +For more details, see the [SDPO Trainer documentation](sdpo_trainer). + +### Self-Training with On-Policy Self-Distillation for Language Model Alignment + +**📜 Paper**: https://huggingface.co/papers/2601.19897 + +Self-Distilled Fine-Tuning (SDFT) performs on-policy self-distillation by generating completions during training, then distilling an explicit teacher-conditioned view of those same completions back into the student. In TRL, SDFT uses a shared self-distillation core with SDPO where the teacher is the model itself (base weights with adapter disabled for PEFT, or the same model under `no_grad` for non-PEFT). +The teacher prompt is composed internally from the student `prompt` plus the dataset `privileged_context`. + +```python +from datasets import Dataset + +from trl.experimental.sdft import SDFTConfig, SDFTTrainer + +dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Solve 2+2."}]], + "privileged_context": ["Example answer: 4."], + } +) + +training_args = SDFTConfig( + distillation_alpha=0.5, + distillation_topk=5, + max_completion_length=64, +) + +trainer = SDFTTrainer( + model="Qwen/Qwen2.5-1.5B-Instruct", + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +Expected dataset columns: + +- `prompt` +- `privileged_context` containing only the extra teacher-only information + +For more details, see the [SDFT Trainer documentation](sdft_trainer). + ## Distributed Training ### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models diff --git a/docs/source/sdft_trainer.md b/docs/source/sdft_trainer.md new file mode 100644 index 00000000000..7f6b1de018b --- /dev/null +++ b/docs/source/sdft_trainer.md @@ -0,0 +1,81 @@ +# SDFT + +Self-Distilled Fine-Tuning (SDFT) is described in [Self-Training with On-Policy Self-Distillation for Language Model Alignment](https://huggingface.co/papers/2601.19897). + +The TRL implementation adapts SDFT to the experimental trainer API while reusing the shared self-distillation infrastructure also used by SDPO. + +In the current TRL implementation: + +- the teacher is the model itself (base weights with adapter disabled for PEFT, or the same model under `no_grad` for non-PEFT); use `sync_ref_model=True` for an EMA teacher +- the dataset must provide both `prompt` and `privileged_context` +- `privileged_context` contains only the extra teacher-only information; the trainer combines it with `prompt` to build the teacher prompt +- `teacher_prompt_template` controls how `prompt` and `privileged_context` are combined into the teacher prompt +- on-policy generation can use either the student prompt or the teacher-conditioned prompt via `generate_from_teacher` +- `num_loss_tokens_to_skip` can exclude initial completion tokens from the distillation loss +- SDFT currently supports text-only training and does not support `use_vllm=True` +- the shared dataset contract is `prompt` plus `privileged_context` + +## Usage + +```python +from datasets import Dataset + +from trl.experimental.sdft import SDFTConfig, SDFTTrainer + +dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Solve 2+2."}]], + "privileged_context": ["Example answer: 4."], + } +) + +training_args = SDFTConfig( + output_dir="sdft-model", + distillation_alpha=0.5, + distillation_topk=5, + max_completion_length=64, +) + +trainer = SDFTTrainer( + model="Qwen/Qwen2.5-1.5B-Instruct", + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +To generate from the teacher-conditioned prompt instead of the student prompt, set `generate_from_teacher=True`. +To customize how the teacher prompt is built, set `teacher_prompt_template` on [`SDFTConfig`]. + +## Expected dataset columns + +Each example must provide: + +- `prompt`: the student-facing prompt +- `privileged_context`: only the extra teacher-only information, such as a demonstration, hint, or privileged feedback + +Both standard text prompts and conversational prompts are supported by the trainer prompt handling. + +## Callbacks + +The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows. + +Shared self-distillation hooks: + +- `on_self_distillation_batch_prepared`: fired when a self-distillation batch is ready. The payload includes `prompt_ids`, `completion_ids`, and `old_per_token_logps` when importance-sampling clipping inputs are available. +- `on_generation_batch_built`: fired when a new buffered generation batch is created. The payload includes `generate_every` and `steps_per_generation`. + +SDFT-specific hook: + +- `on_generation_prompts_selected`: fired when SDFT chooses the prompt source for on-policy generation. The payload includes the selected `generation_prompts` and the corresponding `generation_prompt_text`. + +## SDFTConfig + +[[autodoc]] experimental.sdft.SDFTConfig + +## SDFTTrainer + +[[autodoc]] experimental.sdft.SDFTTrainer + - train + - save_model + - push_to_hub diff --git a/docs/source/sdpo_trainer.md b/docs/source/sdpo_trainer.md new file mode 100644 index 00000000000..11c53588acb --- /dev/null +++ b/docs/source/sdpo_trainer.md @@ -0,0 +1,79 @@ +# SDPO + +Self-Distillation Policy Optimization (SDPO) was introduced in [Reinforcement Learning via Self-Distillation](https://huggingface.co/papers/2601.20802) by [Jonas Hübotter](https://huggingface.co/jonhue), Frederike Lübeck, Lejs Behric, [Anton Baumann](https://huggingface.co/antonbaumann), Marco Bagatella, Daniel Marta, Ido Hakimi, Idan Shenfeld, Thomas Kleine Buening, Carlos Guestrin, and Andreas Krause. + +> Large language models are increasingly post-trained with reinforcement learning in verifiable domains such as code and math. Yet, current methods for reinforcement learning with verifiable rewards (RLVR) learn only from a scalar outcome reward per attempt, creating a severe credit-assignment bottleneck. Many verifiable environments actually provide rich textual feedback, such as runtime errors or judge evaluations, that explain why an attempt failed. We formalize this setting as reinforcement learning with rich feedback and introduce Self-Distillation Policy Optimization (SDPO), which converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed next-token predictions back into the policy. In this way, SDPO leverages the model's ability to retrospectively identify its own mistakes in-context. Across scientific reasoning, tool use, and competitive programming on LiveCodeBench v6, SDPO improves sample efficiency and final accuracy over strong RLVR baselines. Notably, SDPO also outperforms baselines in standard RLVR environments that only return scalar feedback by using successful rollouts as implicit feedback for failed attempts. Finally, applying SDPO to individual questions at test time accelerates discovery on difficult binary-reward tasks, achieving the same discovery probability as best-of-k sampling or multi-turn conversations with 3x fewer attempts. + +The SDPO trainer is built on TRL's experimental shared self-distillation stack. It keeps the online rollout-and-reward training flow, then builds a teacher-conditioned view of the same completions from successful rollouts and optional environment feedback. + +In the current TRL implementation: + +- the default SDPO policy loss mode is `distillation_only` +- `hybrid` mode is also available to combine the base policy loss with the self-distillation loss +- supported teacher regularization modes are `ema` and `none` +- `distillation_topk` is only valid when `full_logit_distillation=True` +- when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` +- environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column + +## Expected dataset columns + +Each example must provide: + +- `prompt`: the student-facing prompt +- `privileged_context`: optional privileged text, such as environment feedback, used when `include_environment_feedback=True` + +## Usage + +```python +from datasets import Dataset + +from trl.experimental.sdpo import SDPOConfig, SDPOTrainer + +dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Solve 2+2."}]], + "privileged_context": ["Your earlier answer used the wrong format."], + } +) + +training_args = SDPOConfig( + output_dir="sdpo-model", + distillation_topk=100, # Top-K logit distillation approximation + full_logit_distillation=True, # Required for top-K; enables non-reverse divergences + include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts +) + +trainer = SDPOTrainer( + model="Qwen/Qwen2.5-1.5B-Instruct", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +SDPO always requires a `prompt` column. To use environment feedback, also include a `privileged_context` column and set `include_environment_feedback=True`. SDPO will use successful rollouts and, when enabled, that text to build teacher reprompts for self-distillation. + +## Callbacks + +The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows. + +Shared self-distillation hooks: + +- `on_self_distillation_batch_prepared`: fired when a self-distillation batch is ready. The payload includes `prompt_ids`, `completion_ids`, and `old_per_token_logps` when importance-sampling clipping inputs are available. +- `on_generation_batch_built`: fired when a new buffered generation batch is created. The payload includes `generate_every` and `steps_per_generation`. + +SDPO-specific hook: + +- `on_teacher_context_built`: fired after SDPO constructs the teacher-conditioned inputs. The payload includes `teacher_input_ids`, `teacher_attention_mask`, `completion_mask`, and `self_distillation_mask`. + +## SDPOConfig + +[[autodoc]] experimental.sdpo.SDPOConfig + +## SDPOTrainer + +[[autodoc]] experimental.sdpo.SDPOTrainer + - train + - save_model + - push_to_hub diff --git a/tests/experimental/test_sdft_trainer.py b/tests/experimental/test_sdft_trainer.py new file mode 100644 index 00000000000..9ca9b6c579a --- /dev/null +++ b/tests/experimental/test_sdft_trainer.py @@ -0,0 +1,342 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, TrainerCallback, TrainerControl, TrainerState, TrainingArguments +from transformers.utils import is_peft_available + +from trl.data_utils import maybe_apply_chat_template +from trl.experimental.sdft import SDFTConfig, SDFTTrainer + +from ..testing_utils import TrlTestCase, require_peft + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model, get_peft_model_state_dict + + from trl.experimental.self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback + + +class SelfDistillationCaptureCallback(TrainerCallback): + def __init__(self): + self.captured_generation_prompt_text = None + self.captured_old_per_token_logps = None + self.generation_batch_build_count = 0 + + def on_generation_prompts_selected(self, generation_prompt_text=None, **kwargs): + if self.captured_generation_prompt_text is None and generation_prompt_text is not None: + self.captured_generation_prompt_text = generation_prompt_text[0] + + def on_self_distillation_batch_prepared(self, old_per_token_logps=None, **kwargs): + if self.captured_old_per_token_logps is None and old_per_token_logps is not None: + self.captured_old_per_token_logps = old_per_token_logps.detach().cpu() + + def on_generation_batch_built(self, **kwargs): + self.generation_batch_build_count += 1 + + +class TestSDFTTrainer(TrlTestCase): + def test_training_rejects_none_privileged_context(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2."], + "privileged_context": [None], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + ) + + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + with pytest.raises(ValueError, match="`privileged_context` must not be None"): + trainer.train() + + def test_training_with_generate_from_teacher(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2.", "Solve 3+3."], + "privileged_context": [ + "Teacher hint: answer with 4 and explain briefly.", + "Teacher hint: answer with 6 and explain briefly.", + ], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + generate_from_teacher=True, + ) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.captured_generation_prompt_text is not None + assert "Solve 2+2." in capture_callback.captured_generation_prompt_text + assert "Teacher hint" in capture_callback.captured_generation_prompt_text + + def test_training_with_chat_template_kwargs(self): + dataset = Dataset.from_dict( + { + "prompt": [ + [{"role": "user", "content": "Solve 2+2."}], + [{"role": "user", "content": "Solve 3+3."}], + ], + "privileged_context": [ + "Teacher hint: answer with 4.", + "Teacher hint: answer with 6.", + ], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + chat_template_kwargs={"enable_thinking": False}, + ) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen3ForCausalLM", + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + expected_prompt = maybe_apply_chat_template( + {"prompt": dataset[0]["prompt"]}, + trainer.processing_class, + **training_args.chat_template_kwargs, + )["prompt"] + + trainer.train() + + assert capture_callback.captured_generation_prompt_text == expected_prompt + + @require_peft + def test_training_with_peft_model(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2.", "Name the capital of France."], + "privileged_context": [ + "Example answer: 4.", + "Example answer: Paris.", + ], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=1, + num_generations=1, + ) + + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig( + task_type="CAUSAL_LM", + target_modules=["q_proj", "v_proj"], + ), + ) + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + @require_peft + def test_training_with_peft_model_and_sync_ref_model(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2.", "Name the capital of France."], + "privileged_context": [ + "Example answer: 4.", + "Example answer: Paris.", + ], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + max_completion_length=8, + max_steps=2, + num_generations=1, + sync_ref_model=True, + ref_model_mixup_alpha=0.05, + ref_model_sync_steps=1, + ) + + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig( + task_type="CAUSAL_LM", + target_modules=["q_proj", "v_proj"], + ), + ) + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + @require_peft + def test_peft_adapter_ema_callback(self): + model = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + device_map="cpu", + ) + lora_config = LoraConfig( + task_type="CAUSAL_LM", + target_modules=["q_proj", "v_proj"], + r=8, + ) + model = get_peft_model(model, lora_config, adapter_name="default") + + update_rate = 0.5 + callback = PEFTAdapterEMACallback( + model=model, + teacher_adapter_name="teacher", + update_rate=update_rate, + sync_steps=1, + ) + + # Initialize and verify teacher adapter was created with zero weights + callback._initialize_teacher_adapter() + assert "teacher" in model.peft_config + assert callback.shadow_weights is not None + + teacher_state = get_peft_model_state_dict(model, adapter_name="teacher") + for key, param in teacher_state.items(): + assert torch.all(param == 0), f"Teacher param {key} should be zero-initialized" + + # Verify shadow weights keys match student state dict keys + student_state = {k: v.clone() for k, v in get_peft_model_state_dict(model, adapter_name="default").items()} + assert set(callback.shadow_weights.keys()) == set(student_state.keys()) + + # Simulate a training step and verify EMA update + args = TrainingArguments(output_dir=self.tmp_dir) + state = TrainerState(global_step=1) + control = TrainerControl() + callback.on_step_end(args, state, control) + + # shadow = (1 - rate) * 0 + rate * student = rate * student + for key in callback.shadow_weights: + expected = update_rate * student_state[key] + torch.testing.assert_close(callback.shadow_weights[key], expected) + + # Verify teacher adapter received the shadow weights + teacher_state = get_peft_model_state_dict(model, adapter_name="teacher") + for key in teacher_state: + torch.testing.assert_close(teacher_state[key].float(), callback.shadow_weights[key]) + + def test_training_populates_old_log_probs_for_distillation_clipping_when_misaligned(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2.", "Solve 3+3."], + "privileged_context": [ + "Example answer: 4.", + "Example answer: 6.", + ], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + gradient_accumulation_steps=3, + steps_per_generation=2, + max_completion_length=8, + max_steps=1, + num_generations=1, + ) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.captured_old_per_token_logps is not None + + def test_training_reuses_buffered_generation_batches(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2.", "Solve 3+3."], + "privileged_context": [ + "Example answer: 4.", + "Example answer: 6.", + ], + } + ) + + training_args = SDFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + steps_per_generation=2, + max_completion_length=8, + max_steps=2, + num_generations=1, + ) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.generation_batch_build_count == 1 diff --git a/tests/experimental/test_sdpo_trainer.py b/tests/experimental/test_sdpo_trainer.py new file mode 100644 index 00000000000..7858442b8b8 --- /dev/null +++ b/tests/experimental/test_sdpo_trainer.py @@ -0,0 +1,417 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from datasets import Dataset, load_dataset +from transformers import TrainerCallback + +from trl.experimental.sdpo import SDPOConfig, SDPOTrainer + +from ..testing_utils import TrlTestCase + + +class SelfDistillationCaptureCallback(TrainerCallback): + def __init__(self): + self.captured_teacher_input_text = None + self.captured_teacher_input_texts = [] + self.captured_self_distillation_mask = None + self.captured_teacher_attention_mask = None + self.captured_completion_mask = None + self.captured_old_per_token_logps = None + + def on_teacher_context_built( + self, + processing_class=None, + teacher_input_ids=None, + teacher_attention_mask=None, + completion_mask=None, + self_distillation_mask=None, + **kwargs, + ): + if self.captured_teacher_input_text is None and teacher_input_ids is not None: + self.captured_teacher_input_text = processing_class.decode(teacher_input_ids[0], skip_special_tokens=True) + if teacher_input_ids is not None: + self.captured_teacher_input_texts.extend( + processing_class.decode(ids, skip_special_tokens=True) for ids in teacher_input_ids + ) + if self.captured_teacher_attention_mask is None and teacher_attention_mask is not None: + self.captured_teacher_attention_mask = teacher_attention_mask.detach().cpu() + if self.captured_completion_mask is None and completion_mask is not None: + self.captured_completion_mask = completion_mask.detach().cpu() + if self.captured_self_distillation_mask is None and self_distillation_mask is not None: + self.captured_self_distillation_mask = self_distillation_mask.detach().cpu() + + def on_self_distillation_batch_prepared(self, old_per_token_logps=None, **kwargs): + if self.captured_old_per_token_logps is None and old_per_token_logps is not None: + self.captured_old_per_token_logps = old_per_token_logps.detach().cpu() + + +class TestSDPOTrainer(TrlTestCase): + def test_training_with_positional_config_argument(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve 2+2."], + "privileged_context": ["Your earlier answer used the wrong format."], + } + ) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + generation_batch_size=2, + num_generations=2, + max_completion_length=8, + include_environment_feedback=True, + max_steps=1, + ) + + trainer = SDPOTrainer( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + lambda **kwargs: [0.0] * len(kwargs["prompts"]), + training_args, + dataset, + ) + + trainer.train() + + assert trainer.args.output_dir == self.tmp_dir + assert trainer.args.include_environment_feedback is True + assert trainer.state.log_history[-1]["train_loss"] is not None + + def test_training(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + distillation_topk=5, + full_logit_distillation=True, + distillation_is_clip=None, + ) + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Parameter {n} has not changed." + + def test_training_without_successful_rollouts(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + distillation_is_clip=None, + ) + + def zero_reward(**kwargs): + prompts = kwargs["prompts"] + return [0.0] * len(prompts) + + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=zero_reward, + args=training_args, + train_dataset=dataset, + ) + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + def test_training_populates_old_log_probs_for_distillation_clipping_when_misaligned(self): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2.", "Solve 3+3."]}) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + gradient_accumulation_steps=3, + steps_per_generation=2, + num_generations=2, + max_completion_length=8, + max_steps=1, + ) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=lambda **kwargs: [0.0] * len(kwargs["prompts"]), + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.captured_old_per_token_logps is not None + + def test_evaluation_uses_num_generations_eval_for_teacher_grouping(self): + eval_dataset = Dataset.from_dict({"prompt": ["Alpha prompt", "Beta prompt", "Gamma prompt", "Delta prompt"]}) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + per_device_eval_batch_size=4, + generation_batch_size=3, + num_generations=3, + num_generations_eval=2, + max_completion_length=8, + success_reward_threshold=0.5, + dont_reprompt_on_self_success=False, + distillation_is_clip=None, + max_steps=1, + ) + + def eval_rewards(**kwargs): + prompts = kwargs["prompts"] + if len(prompts) == 4 and prompts.count("Alpha prompt") == 2 and prompts.count("Beta prompt") == 2: + return [1.0, 0.0, 0.0, 0.0] + return [0.0] * len(prompts) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=eval_rewards, + args=training_args, + train_dataset=eval_dataset.select(range(1)), + eval_dataset=eval_dataset, + callbacks=[capture_callback], + ) + + trainer.evaluate() + + assert capture_callback.captured_teacher_input_texts + alpha_teachers = [text for text in capture_callback.captured_teacher_input_texts if "Alpha prompt" in text] + beta_teachers = [text for text in capture_callback.captured_teacher_input_texts if "Beta prompt" in text] + assert alpha_teachers + assert beta_teachers + assert any("Correct solution:" in text for text in alpha_teachers) + assert all("Correct solution:" not in text for text in beta_teachers) + + def test_teacher_reprompt_preserves_curly_braces_in_solution_and_feedback(self): + dataset = Dataset.from_dict( + { + "prompt": ["Solve f(x) = {x^2}."], + "privileged_context": ['Feedback: use {"x": 2} as a check.'], + } + ) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + generation_batch_size=2, + num_generations=2, + max_completion_length=8, + include_environment_feedback=True, + success_reward_threshold=0.5, + dont_reprompt_on_self_success=False, + max_steps=1, + ) + + def reward_with_one_success(**kwargs): + prompts = kwargs["prompts"] + return [1.0, 0.0][: len(prompts)] + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_with_one_success, + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.captured_teacher_input_text is not None + assert "{{" not in capture_callback.captured_teacher_input_text + assert "}}" not in capture_callback.captured_teacher_input_text + + def test_training_with_conversational_prompts_preserves_context(self): + dataset = Dataset.from_dict( + { + "prompt": [ + [ + {"role": "system", "content": "You are a careful assistant."}, + {"role": "user", "content": "Solve 2+2."}, + ] + ] + } + ) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + generation_batch_size=2, + num_generations=2, + max_completion_length=8, + distillation_is_clip=None, + success_reward_threshold=0.5, + max_steps=1, + ) + + def first_only_reward(**kwargs): + """Only the first sample in each group succeeds — exercises dont_reprompt_on_self_success default.""" + return [1.0, 0.0][: len(kwargs["prompts"])] + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=first_only_reward, + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + # With dont_reprompt_on_self_success=True (default), sample 0 skips itself, + # but sample 1 finds sample 0's success and gets a teacher reprompt. + assert capture_callback.captured_teacher_input_text is not None + assert "careful assistant" in capture_callback.captured_teacher_input_text + assert "Solve 2+2" in capture_callback.captured_teacher_input_text + assert capture_callback.captured_self_distillation_mask is not None + + def test_training_with_feedback_only_reprompts_teacher(self): + dataset = Dataset.from_dict( + { + "prompt": [ + [ + {"role": "system", "content": "You are a careful assistant."}, + {"role": "user", "content": "Try the puzzle again."}, + ] + ], + "privileged_context": ["Your earlier answer violated the format requirements."], + } + ) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + generation_batch_size=2, + num_generations=2, + max_completion_length=8, + distillation_is_clip=None, + include_environment_feedback=True, + max_steps=1, + ) + + def zero_reward(**kwargs): + prompts = kwargs["prompts"] + return [0.0] * len(prompts) + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=zero_reward, + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.captured_teacher_input_text is not None + assert "format requirements" in capture_callback.captured_teacher_input_text + assert capture_callback.captured_self_distillation_mask is not None + assert capture_callback.captured_self_distillation_mask[0].item() == 1.0 + + def test_training_warns_when_sdpo_rewards_are_flat(self, caplog): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + diagnostics_warning_interval=2, + max_steps=2, + ) + + def zero_reward(**kwargs): + return [0.0] * len(kwargs["prompts"]) + + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=zero_reward, + args=training_args, + train_dataset=dataset, + ) + + with caplog.at_level(logging.WARNING): + trainer.train() + + assert "Observed flat SDPO rewards across all sampled generations" in caplog.text + assert "SDPO self-distillation is inactive because no reprompted samples were constructed" in caplog.text + + def test_training_preserves_teacher_completion_attention_mask(self): + dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]}) + + training_args = SDPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=1, + generation_batch_size=2, + num_generations=2, + max_completion_length=8, + success_reward_threshold=0.5, + max_steps=1, + ) + + def first_only_reward(**kwargs): + return [1.0, 0.0][: len(kwargs["prompts"])] + + capture_callback = SelfDistillationCaptureCallback() + trainer = SDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=first_only_reward, + args=training_args, + train_dataset=dataset, + callbacks=[capture_callback], + ) + + trainer.train() + + assert capture_callback.captured_teacher_attention_mask is not None + assert capture_callback.captured_completion_mask is not None + + completion_length = capture_callback.captured_completion_mask.shape[1] + teacher_completion_attention = capture_callback.captured_teacher_attention_mask[0, -completion_length:] + assert torch.equal(teacher_completion_attention, capture_callback.captured_completion_mask[0]) diff --git a/trl/experimental/sdft/__init__.py b/trl/experimental/sdft/__init__.py new file mode 100644 index 00000000000..85a7818ae5c --- /dev/null +++ b/trl/experimental/sdft/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .sdft_config import SDFTConfig +from .sdft_trainer import SDFTTrainer + + +__all__ = ["SDFTConfig", "SDFTTrainer"] diff --git a/trl/experimental/sdft/sdft.py b/trl/experimental/sdft/sdft.py new file mode 100644 index 00000000000..8a7d72896d8 --- /dev/null +++ b/trl/experimental/sdft/sdft.py @@ -0,0 +1,457 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +Small-scale SDFT training with Qwen/Qwen3.5-0.8B. + +Expected dataset formats: + +1. Native TRL self-distillation format: + - `prompt` + - `privileged_context` containing only the extra teacher-only information + +2. Demonstration-based format: + - `prompt` + - `golden_response` + +Example: + +```bash +python trl/experimental/sdft/sdft.py \ + --model_name_or_path Qwen/Qwen3.5-0.8B \ + --dataset_name your-org/your-dataset \ + --output_dir outputs/sdft-qwen3.5-0.8b \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --learning_rate 2e-5 \ + --max_prompt_length 1024 \ + --max_completion_length 512 \ + --generate_from_teacher \ + --sync_ref_model \ + --ref_model_sync_steps 1 \ + --ref_model_mixup_alpha 0.01 \ + --eval_strategy steps \ + --eval_steps 50 \ + --report_to wandb +``` +""" + +import json +import os +import re +from dataclasses import dataclass, field +from string import Template +from typing import Any + +import torch +from datasets import DatasetDict, load_dataset, load_from_disk +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from trl import ( + LogCompletionsCallback, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.data_utils import maybe_apply_chat_template +from trl.experimental.sdft import SDFTConfig, SDFTTrainer +from trl.models import unwrap_model_for_generation + + +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +DEFAULT_DEMONSTRATION_TEMPLATE = Template("""Example response: $output_text""") + + +@dataclass +class SDFTScriptArguments(ScriptArguments): + ref_model_name_or_path: str | None = field( + default=None, + metadata={"help": "Reference teacher model. Optional for PEFT runs, where the base model is used as teacher."}, + ) + dataset_path: str | None = field( + default=None, + metadata={"help": "Optional local dataset path to load with `load_from_disk`. Overrides `dataset_name`."}, + ) + privileged_context_column: str = field( + default="privileged_context", + metadata={"help": "Column containing precomputed privileged context for SDFT."}, + ) + golden_response_column: str = field( + default="golden_response", + metadata={"help": "Column containing demonstration responses used to build privileged context."}, + ) + eval_num_prompts: int | None = field( + default=8, + metadata={"help": "Number of prompts to log during evaluation. Set to 0 to disable completion logging."}, + ) + demonstration_template: str = field( + default=DEFAULT_DEMONSTRATION_TEMPLATE.template, + metadata={"help": "Template used to build privileged context from demonstration content."}, + ) + tool_eval_num_examples: int | None = field( + default=None, + metadata={ + "help": "Optional number of eval examples to score for tool-use metrics. Defaults to the full eval split." + }, + ) + tool_eval_max_new_tokens: int = field( + default=256, + metadata={"help": "Maximum completion length for task evaluation generation."}, + ) + + +@dataclass +class ExampleSDFTConfig(SDFTConfig): + scale_rewards: str = field( + default="group", + metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, + ) + + +def _extract_prompt_text(prompt: Any) -> str: + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list) and prompt and isinstance(prompt[0], dict): + for message in reversed(prompt): + if message.get("role") == "user": + content = message.get("content", "") + if isinstance(content, list): + return " ".join(part.get("text", "") for part in content if part.get("type") == "text") + return content + return str(prompt) + + +def _stringify_golden_response(response: Any) -> str: + if isinstance(response, str): + return response + if isinstance(response, list): + return "\n".join(_stringify_golden_response(item) for item in response) + return str(response) + + +def _build_privileged_context( + example: dict[str, Any], privileged_context_column: str, golden_response_column: str, template: Template +): + if privileged_context_column in example and example[privileged_context_column] is not None: + privileged_context = example[privileged_context_column] + elif golden_response_column in example: + privileged_context = template.safe_substitute( + orig_content=_extract_prompt_text(example["prompt"]), + output_text=_stringify_golden_response(example[golden_response_column]), + ) + elif "teacher_prompt" in example: + raise ValueError( + "Datasets for `trl.experimental.sdft` should provide `privileged_context` or `golden_response`, not " + "`teacher_prompt`." + ) + else: + raise ValueError("Dataset must contain either `privileged_context` or `golden_response` alongside `prompt`.") + + return { + "prompt": example["prompt"], + "privileged_context": privileged_context, + } + + +def _prepare_split(dataset, script_args: SDFTScriptArguments): + template = Template(script_args.demonstration_template) + return dataset.map( + lambda example: _build_privileged_context( + example, + privileged_context_column=script_args.privileged_context_column, + golden_response_column=script_args.golden_response_column, + template=template, + ), + remove_columns=dataset.column_names, + ) + + +def _can_prepare_privileged_context(dataset) -> bool: + columns = set(dataset.column_names) + return "prompt" in columns and ("privileged_context" in columns or "golden_response" in columns) + + +def _extract_action_and_input(text: str) -> tuple[str | None, str | None]: + action_match = re.search(r"Action:\s*([^\n]+)", text) + action_input_match = re.search(r"Action Input:\s*(.*)", text, flags=re.DOTALL) + action = action_match.group(1).strip() if action_match else None + action_input = action_input_match.group(1).strip() if action_input_match else None + return action, action_input + + +def _parse_json_object(text: str | None) -> tuple[bool, Any]: + if text is None: + return False, None + text = text.strip() + if text.startswith("```"): + text = re.sub(r"^```(?:json)?\s*", "", text) + text = re.sub(r"\s*```$", "", text) + try: + return True, json.loads(text) + except Exception: + return False, None + + +def _normalize_gold_answer(example: dict[str, Any]) -> tuple[str | None, Any]: + answers = example.get("golden_answer") or [] + if not answers: + return None, None + answer = answers[0] + action = answer.get("Action") + valid_json, action_input = _parse_json_object(answer.get("Action_Input")) + return action, action_input if valid_json else answer.get("Action_Input") + + +def _apply_prompt_template(tokenizer, prompt: Any) -> str: + return maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] + + +def _run_tooluse_eval( + trainer: SDFTTrainer, + eval_dataset, + max_new_tokens: int, + num_examples: int | None = None, + metric_prefix: str = "tool_eval", +) -> dict[str, float]: + if num_examples is not None: + eval_dataset = eval_dataset.select(range(min(num_examples, len(eval_dataset)))) + + prompts = eval_dataset["prompt"] + prompt_texts = [_apply_prompt_template(trainer.processing_class, prompt) for prompt in prompts] + tokenized = trainer.processing_class( + text=prompt_texts, + return_tensors="pt", + padding=True, + padding_side="left", + truncation=True, + max_length=trainer.max_prompt_length, + add_special_tokens=False, + ) + tokenized = {key: value.to(trainer.accelerator.device) for key, value in tokenized.items()} + + with ( + unwrap_model_for_generation( + trainer.model_wrapped, + trainer.accelerator, + gather_deepspeed3_params=trainer.args.ds3_gather_for_generation, + ) as unwrapped_model, + torch.no_grad(), + ): + generated = unwrapped_model.generate( + **tokenized, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=trainer.processing_class.pad_token_id, + eos_token_id=trainer.processing_class.eos_token_id, + ) + + prompt_length = tokenized["input_ids"].shape[1] + completions = trainer.processing_class.batch_decode(generated[:, prompt_length:], skip_special_tokens=True) + + action_correct = 0 + json_valid = 0 + full_match = 0 + parsed_action_present = 0 + records = [] + + for example, completion in zip(eval_dataset, completions, strict=True): + pred_action, pred_action_input_text = _extract_action_and_input(completion) + if pred_action is not None: + parsed_action_present += 1 + pred_json_valid, pred_action_input = _parse_json_object(pred_action_input_text) + if pred_json_valid: + json_valid += 1 + + gold_action, gold_action_input = _normalize_gold_answer(example) + is_action_correct = pred_action == gold_action and gold_action is not None + if is_action_correct: + action_correct += 1 + is_full_match = is_action_correct and pred_json_valid and pred_action_input == gold_action_input + if is_full_match: + full_match += 1 + + records.append( + { + "prompt": _extract_prompt_text(example["prompt"]), + "completion": completion, + "pred_action": pred_action, + "pred_action_input_text": pred_action_input_text, + "gold_action": gold_action, + "gold_action_input": gold_action_input, + "action_correct": is_action_correct, + "json_valid": pred_json_valid, + "full_match": is_full_match, + } + ) + + total = max(len(eval_dataset), 1) + metrics = { + f"{metric_prefix}/action_present_rate": parsed_action_present / total, + f"{metric_prefix}/valid_json_rate": json_valid / total, + f"{metric_prefix}/action_accuracy": action_correct / total, + f"{metric_prefix}/tool_call_accuracy": full_match / total, + } + + sample_path = os.path.join(trainer.args.output_dir, f"{metric_prefix}_samples.json") + os.makedirs(trainer.args.output_dir, exist_ok=True) + with open(sample_path, "w") as f: + json.dump(records[: min(20, len(records))], f, indent=2) + + return metrics + + +if __name__ == "__main__": + parser = TrlParser((SDFTScriptArguments, ExampleSDFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + if model_args.model_name_or_path is None: + raise ValueError("`model_name_or_path` is required.") + if script_args.ref_model_name_or_path is None and not model_args.use_peft: + script_args.ref_model_name_or_path = model_args.model_name_or_path + + if model_args.dtype in ["auto", None]: + if training_args.bf16: + dtype = torch.bfloat16 + elif training_args.fp16: + dtype = torch.float16 + else: + dtype = "auto" + else: + dtype = getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + training_args.model_init_kwargs = model_kwargs + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if script_args.dataset_path is not None: + dataset = load_from_disk(script_args.dataset_path) + else: + dataset = load_dataset( + script_args.dataset_name, + name=script_args.dataset_config, + streaming=script_args.dataset_streaming, + ) + + if not isinstance(dataset, DatasetDict): + raise ValueError("SDFT example expects a dataset with named splits.") + + train_dataset = _prepare_split(dataset[script_args.dataset_train_split], script_args) + raw_eval_dataset = dataset[script_args.dataset_test_split] if script_args.dataset_test_split in dataset else None + eval_dataset = None + if ( + training_args.eval_strategy != "no" + and raw_eval_dataset is not None + and _can_prepare_privileged_context(raw_eval_dataset) + ): + eval_dataset = _prepare_split(raw_eval_dataset, script_args) + + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) + ref_model = None + if script_args.ref_model_name_or_path is not None: + ref_model = AutoModelForCausalLM.from_pretrained(script_args.ref_model_name_or_path, **model_kwargs) + model.config.use_cache = False if training_args.gradient_checkpointing else True + if ref_model is not None: + ref_model.config.use_cache = True + + trainer = SDFTTrainer( + model=model, + ref_model=ref_model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + if eval_dataset is not None and script_args.eval_num_prompts: + generation_config = GenerationConfig( + max_new_tokens=training_args.max_completion_length, + do_sample=True, + temperature=training_args.temperature, + ) + trainer.add_callback( + LogCompletionsCallback(trainer, generation_config, num_prompts=script_args.eval_num_prompts) + ) + + pretrain_metrics = None + if raw_eval_dataset is not None and "golden_answer" in raw_eval_dataset.column_names: + pretrain_metrics = _run_tooluse_eval( + trainer, + raw_eval_dataset, + max_new_tokens=script_args.tool_eval_max_new_tokens, + num_examples=script_args.tool_eval_num_examples, + metric_prefix="tool_eval_before", + ) + trainer.log(pretrain_metrics) + trainer.log_metrics("eval", pretrain_metrics) + trainer.save_metrics("eval", pretrain_metrics) + + trainer.train() + + trainer.save_model(training_args.output_dir) + if eval_dataset is not None: + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if raw_eval_dataset is not None and "golden_answer" in raw_eval_dataset.column_names: + post_metrics = _run_tooluse_eval( + trainer, + raw_eval_dataset, + max_new_tokens=script_args.tool_eval_max_new_tokens, + num_examples=script_args.tool_eval_num_examples, + metric_prefix="tool_eval_after", + ) + if pretrain_metrics is not None: + for key, value in pretrain_metrics.items(): + after_key = key.replace("tool_eval_before/", "tool_eval_after/") + if after_key in post_metrics: + delta_name = after_key.replace("tool_eval_after/", "tool_eval_delta/") + post_metrics[delta_name] = post_metrics[after_key] - value + trainer.log(post_metrics) + trainer.log_metrics("eval", post_metrics) + trainer.save_metrics("eval", post_metrics) + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name or script_args.dataset_path) diff --git a/trl/experimental/sdft/sdft_config.py b/trl/experimental/sdft/sdft_config.py new file mode 100644 index 00000000000..84227e43cbf --- /dev/null +++ b/trl/experimental/sdft/sdft_config.py @@ -0,0 +1,68 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..self_distillation.self_distillation_config import SelfDistillationConfig + + +@dataclass +class SDFTConfig(SelfDistillationConfig): + r""" + Configuration class for [`SDFTTrainer`]. + + This adapts the official SDFT implementation to the TRL trainer API while reusing the common self-distillation + configuration shared with SDPO. + + Parameters: + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the student and teacher models. + generate_from_teacher (`bool`, *optional*, defaults to `False`): + Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt. + teacher_prompt_template (`str`, *optional*, defaults to `"{prompt}\n\n{privileged_context}"`): + Template used to combine the student prompt and privileged context into the teacher prompt. + num_loss_tokens_to_skip (`int`, *optional*, defaults to `0`): + Number of initial completion tokens to exclude from the distillation loss. + """ + + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the student and teacher models."}, + ) + generate_from_teacher: bool = field( + default=False, + metadata={"help": "Whether on-policy generation should use the teacher-conditioned prompt."}, + ) + teacher_prompt_template: str = field( + default="{prompt}\n\n{privileged_context}", + metadata={ + "help": "Template used to combine the student prompt and privileged context into the teacher prompt." + }, + ) + num_loss_tokens_to_skip: int = field( + default=0, + metadata={"help": "Number of initial completion tokens to exclude from the distillation loss."}, + ) + + def __post_init__(self): + super().__post_init__() + if ( + "{prompt}" not in self.teacher_prompt_template + or "{privileged_context}" not in self.teacher_prompt_template + ): + raise ValueError( + "teacher_prompt_template must contain both `{prompt}` and `{privileged_context}` placeholders" + ) + if self.num_loss_tokens_to_skip < 0: + raise ValueError("num_loss_tokens_to_skip must be non-negative") diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py new file mode 100644 index 00000000000..010089e7ac3 --- /dev/null +++ b/trl/experimental/sdft/sdft_trainer.py @@ -0,0 +1,490 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import inspect +import textwrap +from collections import defaultdict +from functools import partial +from typing import Any + +import datasets +import torch +from accelerate.logging import get_logger +from accelerate.utils import is_peft_model +from datasets import Dataset, IterableDataset +from torch import nn +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoProcessor, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_peft_available + +from ...models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.callbacks import SyncRefModelCallback +from ...trainer.utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + get_config_model_id, + identity, + pad, + split_tensor_dict, + use_adapter, +) +from ..self_distillation.self_distillation_mixin import SelfDistillationMixin +from ..self_distillation.teacher_context import PromptTokenizer, extract_last_user_text +from ..utils import prepare_peft_model +from .sdft_config import SDFTConfig + + +if is_peft_available(): + from peft import PeftConfig + from peft.peft_model import PeftModel + + from ..self_distillation.peft_adapter_ema_callback import PEFTAdapterEMACallback + + +logger = get_logger(__name__) + + +class DemonstrationTeacherContextBuilder: + """Builds student and teacher contexts from prompts plus privileged context, as in SDFT.""" + + def __init__(self, trainer): + self.trainer = trainer + self.prompt_tokenizer = PromptTokenizer(trainer) + + def _stringify_privileged_context(self, privileged_context: Any) -> str: + if privileged_context is None: + raise ValueError( + "`privileged_context` must not be None for self-distillation teacher prompt construction." + ) + if isinstance(privileged_context, str): + return privileged_context + if isinstance(privileged_context, list) and privileged_context and isinstance(privileged_context[0], dict): + chunks = [] + for message in privileged_context: + content = message.get("content", "") + if isinstance(content, list): + text = " ".join(part.get("text", "") for part in content if part.get("type") == "text") + else: + text = str(content) + if text: + chunks.append(text) + return "\n".join(chunks) + return str(privileged_context) + + def _compose_teacher_prompt(self, prompt: Any, privileged_context: Any) -> Any: + privileged_text = self._stringify_privileged_context(privileged_context) + if isinstance(prompt, list): + system_messages = prompt[:-1] + prompt_text = extract_last_user_text(prompt) + teacher_text = self.trainer.args.teacher_prompt_template.format( + prompt=prompt_text, + privileged_context=privileged_text, + ) + return system_messages + [{"role": "user", "content": teacher_text}] + return self.trainer.args.teacher_prompt_template.format(prompt=prompt, privileged_context=privileged_text) + + def select_generation_prompts(self, prompts: list[Any], privileged_contexts: list[Any]) -> list[Any]: + if not self.trainer.generate_from_teacher: + return prompts + return [ + self._compose_teacher_prompt(prompt, privileged_context) + for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) + ] + + def build( + self, + prompts: list[Any], + privileged_contexts: list[Any], + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + ) -> dict[str, torch.Tensor]: + student_batch = self.prompt_tokenizer.tokenize_prompts(prompts) + teacher_prompts = [ + self._compose_teacher_prompt(prompt, privileged_context) + for prompt, privileged_context in zip(prompts, privileged_contexts, strict=True) + ] + teacher_batch = self.prompt_tokenizer.tokenize_prompts(teacher_prompts) + teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + return { + "prompt_ids": student_batch.prompt_ids, + "prompt_mask": student_batch.prompt_mask, + "teacher_input_ids": teacher_input_ids, + "teacher_attention_mask": teacher_attention_mask, + } + + +class SDFTTrainer(SelfDistillationMixin, _BaseTrainer): + """Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.""" + + _tag_names = ["trl", "sdft"] + _name = "SDFT" + config_cls = SDFTConfig + # docstyle-ignore + _paper = { + "title": "Self-Training with On-Policy Self-Distillation for Language Model Alignment", + "id": "2601.19897", + "citation": textwrap.dedent("""\ + @article{hubotter2026selftraining, + title = {{Self-Training with On-Policy Self-Distillation for Language Model Alignment}}, + author = {Jonas H\\"ubotter and Frederike L\\"ubeck and Lejs Behric and Anton Baumann and Marco Bagatella and Daniel Marta and Ido Hakimi and Idan Shenfeld and Thomas Kleine Buening and Carlos Guestrin and Andreas Krause}, + year = 2026, + eprint = {arXiv:2601.19897} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + args: SDFTConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: PeftConfig | None = None, + ): + if train_dataset is None: + raise ValueError("`train_dataset` is required") + if isinstance(train_dataset, IterableDataset): + raise NotImplementedError("Iterable datasets are not yet supported in SDFTTrainer.") + if isinstance(eval_dataset, IterableDataset) or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ): + raise NotImplementedError("Iterable eval datasets are not yet supported in SDFTTrainer.") + if args.use_vllm: + raise NotImplementedError("SDFTTrainer does not support `use_vllm=True` yet.") + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + elif args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to `SDFTConfig`, but `model` is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." + ) + if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): + model = prepare_peft_model(model, peft_config, args) + + if processing_class is None: + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) + + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.num_iterations = args.num_iterations + self.temperature = args.temperature + self.loss_type = args.loss_type + self.shuffle_dataset = args.shuffle_dataset + self.generate_from_teacher = args.generate_from_teacher + self.num_loss_tokens_to_skip = args.num_loss_tokens_to_skip + self.chat_template_kwargs = args.chat_template_kwargs or {} + self._step = 0 + self._buffered_inputs = None + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self.prompt_tokenizer = PromptTokenizer(self) + self.teacher_context_builder = DemonstrationTeacherContextBuilder(self) + + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "repetition_penalty": args.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs) + + if hasattr(model, "warnings_issued"): + model.warnings_issued["estimate_tokens"] = True + + super().__init__( + model=model, + args=args, + data_collator=identity, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + compute_loss_func="non-None value to disable scaling", + ) + + if args.disable_dropout: + disable_dropout_in_model(self.model) + + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + # In self-distillation the teacher is always derived from the student: + # - PEFT: base model with adapter disabled (or EMA teacher adapter when sync_ref_model=True) + # - Non-PEFT: same model (or deep-copied EMA model when sync_ref_model=True) + self.teacher_model = None + + if args.sync_ref_model: + if is_peft_available() and is_peft_model(self.model): + self.add_callback( + PEFTAdapterEMACallback( + model=self.model, + teacher_adapter_name="teacher", + update_rate=args.ref_model_mixup_alpha, + sync_steps=args.ref_model_sync_steps, + accelerator=self.accelerator, + ) + ) + else: + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(self.teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + self.teacher_model = prepare_fsdp(self.teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(self.teacher_model, evaluation_mode=True) + self.add_callback(SyncRefModelCallback(ref_model=self.teacher_model, accelerator=self.accelerator)) + + self.model_accepts_loss_kwargs = False + + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset=None) -> Sampler: + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations, + seed=self.args.seed, + ) + + def training_step(self, model, inputs, num_items_in_batch): + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + return output + + def _prepare_inputs(self, generation_batch): + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + generation_batch = self._build_buffered_batch(generation_batch) + self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) + self._dispatch_self_distillation_callback( + "on_generation_batch_built", + generate_every=generate_every, + steps_per_generation=self.args.steps_per_generation, + ) + return self._buffered_inputs[self._step % self.args.steps_per_generation] + return self._build_buffered_batch(generation_batch) + + def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, torch.Tensor]: + generate_inputs = self.processing_class( + text=self.prompt_tokenizer.apply_prompt_template(prompts), + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + # This generation helper builds tokenized model inputs directly, so use the base Trainer tensor preparation + # instead of re-entering the buffered outer training hook. + generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) + + with ( + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model, + torch.no_grad(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, + generation_config=self.generation_config, + disable_compile=True, + ) + + prompt_length = generate_inputs["input_ids"].size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long() + + completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] + completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + return ( + pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), + pad(completion_mask, padding_value=0, padding_side="right"), + ) + + def _build_buffered_batch(self, inputs: list[dict[str, Any]]) -> dict[str, torch.Tensor | Any]: + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + generation_prompts = self.teacher_context_builder.select_generation_prompts(prompts, privileged_contexts) + generation_prompt_text = self.prompt_tokenizer.apply_prompt_template(generation_prompts) + self._dispatch_self_distillation_callback( + "on_generation_prompts_selected", + generation_prompts=generation_prompts, + generation_prompt_text=generation_prompt_text, + ) + completion_ids, completion_mask = self._generate_completion_ids(generation_prompts) + + teacher_batch = self.teacher_context_builder.build( + prompts, privileged_contexts, completion_ids, completion_mask + ) + + prompt_completion_ids = torch.cat([teacher_batch["prompt_ids"], completion_ids], dim=1) + attention_mask = torch.cat([teacher_batch["prompt_mask"], completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + + with torch.no_grad(): + generate_every = self.args.steps_per_generation * self.num_iterations + if not self.generate_from_teacher and self.args.gradient_accumulation_steps % generate_every != 0: + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + else: + old_per_token_logps = None + + self._dispatch_self_distillation_callback( + "on_self_distillation_batch_prepared", + old_per_token_logps=old_per_token_logps, + prompt_ids=teacher_batch["prompt_ids"], + completion_ids=completion_ids, + ) + output = { + "prompt_ids": teacher_batch["prompt_ids"], + "prompt_mask": teacher_batch["prompt_mask"], + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "teacher_input_ids": teacher_batch["teacher_input_ids"], + "teacher_attention_mask": teacher_batch["teacher_attention_mask"], + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + return output + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The SDFTTrainer does not support returning outputs") + + if self.num_loss_tokens_to_skip > 0: + inputs = dict(inputs) + completion_mask = inputs["completion_mask"].clone() + token_positions = torch.arange(completion_mask.size(1), device=completion_mask.device).unsqueeze(0) + completion_mask = completion_mask * (token_positions >= self.num_loss_tokens_to_skip).long() + inputs["completion_mask"] = completion_mask + + loss = self._compute_self_distillation_loss(model, inputs) + accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + return loss / accumulation_scale + + def _get_teacher_context_for_self_distillation(self, model): + if is_peft_available() and isinstance(self.model, PeftModel): + model = self.accelerator.unwrap_model(self.model) + if self.args.sync_ref_model and "teacher" in model.peft_config: + return use_adapter(model, adapter_name="teacher") + return use_adapter(model, adapter_name=None) + return super()._get_teacher_context_for_self_distillation(model) diff --git a/trl/experimental/sdpo/__init__.py b/trl/experimental/sdpo/__init__.py new file mode 100644 index 00000000000..f50a54cf7c8 --- /dev/null +++ b/trl/experimental/sdpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .sdpo_config import SDPOConfig +from .sdpo_trainer import SDPOTrainer + + +__all__ = ["SDPOConfig", "SDPOTrainer"] diff --git a/trl/experimental/sdpo/sdpo.py b/trl/experimental/sdpo/sdpo.py new file mode 100644 index 00000000000..6723b7b5919 --- /dev/null +++ b/trl/experimental/sdpo/sdpo.py @@ -0,0 +1,394 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "math-verify", +# "latex2sympy2_extended", +# "trackio", +# "kernels", +# ] +# /// + +""" +Usage: + +```bash +python trl/experimental/sdpo/sdpo.py \ + --model_name_or_path Qwen/Qwen2.5-Math-1.5B-Instruct \ + --dataset_name openai/gsm8k \ + --dataset_config main \ + --output_dir outputs/sdpo-qwen35-2b-gsm8k \ + --learning_rate 5e-5 \ + --dtype bfloat16 \ + --bf16 true \ + --max_completion_length 128 \ + --use_peft \ + --lora_target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 2 \ + --num_generations 8 \ + --generation_batch_size 32 \ + --distillation_alpha 1.0 \ + --full_logit_distillation false \ + --sdpo_policy_loss_mode hybrid \ + --report_to none \ + --eval_strategy steps \ + --eval_steps 1000 \ + --save_strategy no \ + --eval_num_prompts 0 \ + --accuracy_eval_num_examples 64 \ + --max_train_examples 256 \ + --max_eval_examples 128 +``` + +This example uses verifiable math rewards and reports answer accuracy before and after training. If your dataset +already contains textual environment feedback, pass the column name via `--feedback_column`; it will be forwarded as +`privileged_context` for SDPO reprompting. +""" + +import os +import re +from dataclasses import dataclass, field +from typing import Any + +import torch +from datasets import DatasetDict, load_dataset, load_from_disk +from transformers import AutoTokenizer, GenerationConfig + +from trl import ( + LogCompletionsCallback, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.data_utils import maybe_apply_chat_template +from trl.experimental.sdpo import SDPOConfig, SDPOTrainer + + +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +SYSTEM_PROMPT = ( + "A conversation between user and assistant. The user asks a question, and the assistant solves it. The assistant " + "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " + "must be enclosed within tags, and the final answer must be on its own line in the format " + "`#### `." +) + + +@dataclass +class SDPOScriptArguments(ScriptArguments): + dataset_path: str | None = field( + default=None, + metadata={"help": "Optional local dataset path to load with `load_from_disk`. Overrides `dataset_name`."}, + ) + feedback_column: str | None = field( + default=None, + metadata={ + "help": "Optional dataset column containing textual environment feedback to pass as `privileged_context`." + }, + ) + eval_num_prompts: int | None = field( + default=8, + metadata={"help": "Number of prompts to log during evaluation. Set to 0 to disable completion logging."}, + ) + accuracy_eval_num_examples: int | None = field( + default=128, + metadata={"help": "Optional number of eval examples to score for answer accuracy. Defaults to 128."}, + ) + accuracy_eval_max_new_tokens: int = field( + default=128, + metadata={"help": "Maximum completion length for answer-accuracy evaluation generation."}, + ) + feedback_from_solution: str | None = field( + default=None, + metadata={ + "help": "Optional synthesized feedback source when the dataset has no feedback column. Supported: " + "`final_answer`, `full_solution`." + }, + ) + max_train_examples: int | None = field( + default=None, + metadata={"help": "Optional cap on the number of training examples loaded from the selected train split."}, + ) + max_eval_examples: int | None = field( + default=None, + metadata={"help": "Optional cap on the number of evaluation examples loaded from the selected eval split."}, + ) + dataset_shuffle_seed: int = field( + default=42, + metadata={"help": "Random seed used before applying `max_train_examples` or `max_eval_examples`."}, + ) + + +@dataclass +class ExampleSDPOConfig(SDPOConfig): + scale_rewards: str = field( + default="group", + metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, + ) + + +def _make_solution_feedback(final_answer: str, worked_solution: str, feedback_from_solution: str | None) -> str | None: + if feedback_from_solution is None: + return None + if feedback_from_solution == "final_answer": + return ( + "Your previous answer was incorrect. The correct final answer is:\n\n" + f"#### {final_answer}\n\n" + "Revise your reasoning and end with the same final answer format." + ) + if feedback_from_solution == "full_solution": + return ( + "Your previous answer was incorrect. Here is a correct worked solution:\n\n" + f"{worked_solution}\n\n" + "Use it to solve the original question correctly." + ) + raise ValueError("feedback_from_solution must be one of: `final_answer`, `full_solution`.") + + +def _make_conversation( + example: dict[str, Any], feedback_column: str | None, feedback_from_solution: str | None +) -> dict[str, Any]: + prompt = example.get("prompt") + if prompt is None and "problem" in example: + prompt = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": example["problem"]}, + ] + if prompt is None and "question" in example: + prompt = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": example["question"]}, + ] + + if prompt is None: + raise ValueError("Each example must provide one of: `prompt`, `problem`, or `question`.") + + output = {"prompt": prompt} + + solution = None + if "solution" in example: + solution = example["solution"] + elif "answer" in example: + solution = _normalize_gsm8k_answer(example["answer"]) + + if solution is not None: + output["solution"] = solution + + if feedback_column is not None and feedback_column in example: + output["privileged_context"] = example[feedback_column] + elif "privileged_context" in example: + output["privileged_context"] = example["privileged_context"] + elif solution is not None: + worked_solution = example.get("solution") + if worked_solution is None and "answer" in example: + worked_solution = example["answer"].strip() + if worked_solution is None: + worked_solution = f"#### {solution}" + synthesized_feedback = _make_solution_feedback(solution, worked_solution, feedback_from_solution) + if synthesized_feedback is not None: + output["privileged_context"] = synthesized_feedback + + return output + + +def _normalize_gsm8k_answer(answer_text: str) -> str: + if "####" not in answer_text: + return answer_text.strip() + return answer_text.split("####", 1)[1].strip().replace(",", "") + + +def _extract_predicted_answer(completion_text: str) -> str | None: + match = re.search(r"####\s*([^\n]+)", completion_text) + if match: + return match.group(1).strip().replace(",", "") + + matches = re.findall(r"(-?\$?[0-9][0-9,]*(?:\.[0-9]+)?)", completion_text) + if not matches: + return None + return matches[-1].replace("$", "").replace(",", "").strip() + + +def _gsm8k_accuracy_reward(completions, solution, **kwargs) -> list[float]: + rewards = [] + for completion, gold in zip(completions, solution, strict=True): + content = completion[0]["content"] if isinstance(completion, list) else completion + pred = _extract_predicted_answer(content) + rewards.append(1.0 if pred is not None and pred == gold else 0.0) + return rewards + + +def _gsm8k_soft_format_reward(completions, **kwargs) -> list[float]: + pattern = r".*?\s*####\s*[^\n]+" + rewards = [] + for completion in completions: + content = completion[0]["content"] if isinstance(completion, list) else completion + rewards.append(0.25 if re.match(pattern, content, flags=re.DOTALL) else 0.0) + return rewards + + +def _run_accuracy_eval( + trainer: SDPOTrainer, eval_dataset, max_new_tokens: int, num_examples: int | None, metric_prefix: str = "math_eval" +) -> dict[str, float]: + if num_examples is not None: + eval_dataset = eval_dataset.select(range(min(num_examples, len(eval_dataset)))) + + prompts = eval_dataset["prompt"] + prompt_texts = [ + maybe_apply_chat_template({"prompt": prompt}, trainer.processing_class)["prompt"] for prompt in prompts + ] + tokenized = trainer.processing_class( + text=prompt_texts, + return_tensors="pt", + padding=True, + padding_side="left", + truncation=True, + max_length=trainer.max_prompt_length, + add_special_tokens=False, + ) + tokenized = {key: value.to(trainer.accelerator.device) for key, value in tokenized.items()} + model = trainer.accelerator.unwrap_model(trainer.model) + was_training = model.training + model.eval() + with torch.no_grad(): + generated = model.generate( + **tokenized, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=trainer.processing_class.pad_token_id, + eos_token_id=trainer.processing_class.eos_token_id, + ) + if was_training: + model.train() + + prompt_length = tokenized["input_ids"].shape[1] + completions = trainer.processing_class.batch_decode(generated[:, prompt_length:], skip_special_tokens=True) + completion_messages = [[{"role": "assistant", "content": completion}] for completion in completions] + rewards = _gsm8k_accuracy_reward(completion_messages, solution=eval_dataset["solution"]) + total = max(len(rewards), 1) + return { + f"{metric_prefix}/accuracy": sum(rewards) / total, + f"{metric_prefix}/num_scored": float(len(rewards)), + } + + +if __name__ == "__main__": + parser = TrlParser((SDPOScriptArguments, ExampleSDPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + training_args.model_init_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + training_args.model_init_kwargs["device_map"] = get_kbit_device_map() + training_args.model_init_kwargs["quantization_config"] = quantization_config + + if script_args.dataset_path is not None: + dataset = load_from_disk(script_args.dataset_path) + else: + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + if not isinstance(dataset, DatasetDict): + raise ValueError("SDPO example expects a dataset with named splits.") + + train_split = dataset[script_args.dataset_train_split] + if script_args.max_train_examples is not None: + train_split = train_split.shuffle(seed=script_args.dataset_shuffle_seed).select( + range(min(script_args.max_train_examples, len(train_split))) + ) + + train_dataset = train_split.map( + lambda example: _make_conversation(example, script_args.feedback_column, script_args.feedback_from_solution), + remove_columns=train_split.column_names, + ) + eval_dataset = None + if training_args.eval_strategy != "no": + eval_split = dataset[script_args.dataset_test_split] + if script_args.max_eval_examples is not None: + eval_split = eval_split.shuffle(seed=script_args.dataset_shuffle_seed).select( + range(min(script_args.max_eval_examples, len(eval_split))) + ) + + eval_dataset = eval_split.map( + lambda example: _make_conversation( + example, script_args.feedback_column, script_args.feedback_from_solution + ), + remove_columns=eval_split.column_names, + ) + + reward_funcs = [_gsm8k_soft_format_reward, _gsm8k_accuracy_reward] + + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + trainer = SDPOTrainer( + model=model_args.model_name_or_path, + args=training_args, + reward_funcs=reward_funcs, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=get_peft_config(model_args), + processing_class=tokenizer, + ) + + if eval_dataset is not None and script_args.eval_num_prompts: + generation_config = GenerationConfig( + max_new_tokens=training_args.max_completion_length, + do_sample=True, + temperature=training_args.temperature, + ) + trainer.add_callback( + LogCompletionsCallback(trainer, generation_config, num_prompts=script_args.eval_num_prompts) + ) + + if eval_dataset is not None: + pre_metrics = _run_accuracy_eval( + trainer, + eval_dataset, + max_new_tokens=script_args.accuracy_eval_max_new_tokens, + num_examples=script_args.accuracy_eval_num_examples, + ) + trainer.log_metrics("eval", {f"before_{k}": v for k, v in pre_metrics.items()}) + trainer.save_metrics("eval", {f"before_{k}": v for k, v in pre_metrics.items()}) + + trainer.train() + + trainer.save_model(training_args.output_dir) + if eval_dataset is not None: + post_metrics = _run_accuracy_eval( + trainer, + eval_dataset, + max_new_tokens=script_args.accuracy_eval_max_new_tokens, + num_examples=script_args.accuracy_eval_num_examples, + ) + after_metrics = {f"after_{k}": v for k, v in post_metrics.items()} + delta_metrics = { + f"delta_{k.split('/', 1)[1]}": after_metrics[f"after_{k}"] - pre_metrics[k] for k in pre_metrics + } + trainer.log_metrics("eval", after_metrics | delta_metrics) + trainer.save_metrics("eval", after_metrics | delta_metrics) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name or script_args.dataset_path) diff --git a/trl/experimental/sdpo/sdpo_config.py b/trl/experimental/sdpo/sdpo_config.py new file mode 100644 index 00000000000..1cc8c1510b7 --- /dev/null +++ b/trl/experimental/sdpo/sdpo_config.py @@ -0,0 +1,148 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..self_distillation import SelfDistillationConfig + + +@dataclass +class SDPOConfig(SelfDistillationConfig): + r""" + Configuration class for the [`SDPOTrainer`]. + + This class extends [`experimental.self_distillation.SelfDistillationConfig`] with the online teacher-construction + parameters used by Self-Distillation Policy Optimization (SDPO). + + Parameters: + > Parameters that control the SDPO loss + + sdpo_policy_loss_mode (`str`, *optional*, defaults to `"distillation_only"`): + How SDPO combines the online policy loss and self-distillation loss. Supported: `distillation_only`, + `hybrid`. + distillation_alpha (`float`, *optional*, defaults to `1.0`): + Divergence interpolation coefficient. Token-level SDPO requires the official reverse-KL setting + `distillation_alpha=1.0`. + distillation_topk (`int` or `None`, *optional*): + Top-k approximation for logit-level SDPO. Requires `full_logit_distillation=True`. + + > Parameters that control the teacher + + teacher_regularization (`str`, *optional*, defaults to `"ema"`): + Teacher update strategy. Supported: `ema`, `none`. + teacher_update_rate (`float` or `None`, *optional*): + EMA update rate used when `teacher_regularization="ema"`. + ema_update_rate (`float`, *optional*, defaults to `0.05`): + Deprecated alias for `teacher_update_rate`. + + > Parameters that control reprompting + + use_successful_as_teacher (`bool`, *optional*, defaults to `True`): + Whether successful rollouts are turned into teacher demonstrations. + success_reward_threshold (`float`, *optional*, defaults to `1.0`): + Minimum reward for a rollout to count as successful. + include_environment_feedback (`bool`, *optional*, defaults to `False`): + Whether `privileged_context` is injected into teacher reprompts when available. + """ + + dont_reprompt_on_self_success: bool = field( + default=True, + metadata={"help": "Skip reprompting when model generates correct response."}, + ) + distillation_alpha: float = field( + default=1.0, + metadata={ + "help": "KL divergence direction for SDPO. Token-level SDPO requires reverse KL (`distillation_alpha=1.0`)." + }, + ) + distillation_topk: int | None = field( + default=None, + metadata={"help": "Top-K approximation for logit-level SDPO. Requires `full_logit_distillation=True`."}, + ) + sdpo_policy_loss_mode: str = field( + default="distillation_only", + metadata={"help": "SDPO policy loss mode. Supported: `distillation_only`, `hybrid`."}, + ) + teacher_regularization: str = field( + default="ema", + metadata={"help": "Teacher regularization mode. Supported: `ema`, `none`."}, + ) + teacher_update_rate: float | None = field( + default=None, + metadata={"help": "Teacher update rate used for EMA teacher synchronization."}, + ) + ema_update_rate: float = field( + default=0.05, + metadata={"help": "Deprecated alias for `teacher_update_rate`."}, + ) + max_reprompt_len: int = field( + default=10240, + metadata={"help": "Maximum length for reprompting in self-distillation."}, + ) + use_successful_as_teacher: bool = field( + default=True, + metadata={"help": "Use successful rollouts as implicit feedback for self-distillation."}, + ) + success_reward_threshold: float = field( + default=1.0, + metadata={"help": "Minimum reward for a rollout to be considered a successful demonstration."}, + ) + reprompt_template: str = field( + default="{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n", + metadata={"help": "Template for reprompting the teacher with a successful demonstration."}, + ) + solution_template: str = field( + default="\nCorrect solution:\n\n{successful_previous_attempt}\n\n", + metadata={"help": "Template for formatting the successful demonstration text."}, + ) + feedback_template: str = field( + default="\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n", + metadata={"help": "Template for formatting environment feedback for reprompting."}, + ) + include_environment_feedback: bool = field( + default=False, + metadata={"help": "Whether to include environment feedback in teacher reprompts when available."}, + ) + environment_feedback_only_without_solution: bool = field( + default=False, + metadata={"help": "Whether to use feedback only when no successful solution is available."}, + ) + remove_thinking_from_demonstration: bool = field( + default=False, + metadata={"help": "Whether to remove ... blocks from the demonstration text."}, + ) + + def __post_init__(self): + super().__post_init__() + + if self.teacher_update_rate is None: + self.teacher_update_rate = self.ema_update_rate + + if self.teacher_regularization not in {"ema", "none"}: + raise ValueError("teacher_regularization must be one of: 'ema', 'none'") + if not 0.0 <= self.teacher_update_rate <= 1.0: + raise ValueError("teacher_update_rate must be in [0, 1]") + if self.sdpo_policy_loss_mode not in {"distillation_only", "hybrid"}: + raise ValueError("sdpo_policy_loss_mode must be one of: 'distillation_only', 'hybrid'") + if self.sdpo_policy_loss_mode == "distillation_only" and self.distillation_weight <= 0: + raise ValueError("distillation_only mode requires `distillation_weight > 0`.") + if self.max_reprompt_len <= 0: + raise ValueError("max_reprompt_len must be positive") + if not self.full_logit_distillation and self.distillation_alpha != 1.0: + raise ValueError( + "SDPO token-level distillation requires `distillation_alpha=1.0`. " + "Set `full_logit_distillation=True` to use other divergence settings." + ) + if self.distillation_topk is not None and not self.full_logit_distillation: + raise ValueError("SDPO `distillation_topk` requires `full_logit_distillation=True`.") diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py new file mode 100644 index 00000000000..ef84a17a44c --- /dev/null +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -0,0 +1,387 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import re +import textwrap +from typing import Any + +import torch +from accelerate.utils import gather_object +from datasets import Dataset, IterableDataset +from torch import nn +from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback + +from ...trainer.callbacks import SyncRefModelCallback +from ...trainer.utils import pad +from ..self_distillation.base_self_distillation_trainer import BaseSelfDistillationTrainer +from ..self_distillation.teacher_context import TokenizedPromptBatch, extract_last_user_text +from .sdpo_config import SDPOConfig + + +class EMATeacherSyncCallback(SyncRefModelCallback): + """Synchronize an EMA teacher model with the student model on each step.""" + + def __init__(self, teacher_model, update_rate: float, accelerator=None): + super().__init__(ref_model=teacher_model, accelerator=accelerator) + self.update_rate = update_rate + + def on_step_end(self, args, state, control, **kwargs): + model = kwargs["model"] + if self.accelerator is not None: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, self.update_rate) + + +class SuccessfulRolloutTeacherContextBuilder: + """Builds SDPO teacher contexts from successful rollouts, following the official online implementation.""" + + def __init__(self, trainer): + self.trainer = trainer + self.last_metrics: dict[str, float] = {} + + def _build_reprompt_text(self, prompt_text: str, solution_text: str, feedback_text: str) -> str: + return self.trainer.args.reprompt_template.format( + prompt=prompt_text, + solution=solution_text, + feedback=feedback_text, + ) + + def _tokenize_teacher_messages( + self, teacher_messages_list: list[str | list[dict[str, Any]]] + ) -> TokenizedPromptBatch: + teacher_prompt_ids_list = [] + device = self.trainer.accelerator.device + chat_template_kwargs = getattr(self.trainer, "chat_template_kwargs", {}) + for msg in teacher_messages_list: + if isinstance(msg, list) and isinstance(msg[0], dict): + tokenized = self.trainer.processing_class.apply_chat_template( + msg, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + **chat_template_kwargs, + ) + if isinstance(tokenized, torch.Tensor): + ids = tokenized.squeeze(0) + else: + ids = tokenized["input_ids"].squeeze(0) + else: + ids = self.trainer.processing_class.encode(msg, return_tensors="pt").squeeze(0) + + if ids.shape[0] > self.trainer.args.max_reprompt_len: + ids = ids[-self.trainer.args.max_reprompt_len :] + teacher_prompt_ids_list.append(ids) + + teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] + teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] + return TokenizedPromptBatch( + prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), + ) + + def build( + self, + output: dict[str, torch.Tensor | Any], + prompts: list[Any], + rewards: torch.Tensor, + feedbacks: list[Any] | None = None, + ) -> dict[str, torch.Tensor]: + device = self.trainer.accelerator.device + mode = "train" if self.trainer.model.training else "eval" + num_generations = self.trainer.num_generations if mode == "train" else self.trainer.num_generations_eval + completion_ids = output["completion_ids"] + completion_mask = output["completion_mask"] + + num_local = len(prompts) + process_start = self.trainer.accelerator.process_index * num_local + process_slice = slice(process_start, process_start + num_local) + + # Rewards arrive already locally sliced (per-process) from the rollout mixin; re-gather them so + # the mining loop can find successful rollouts across all processes within each generation group. + all_rewards = self.trainer.accelerator.gather(rewards) + # Completion tensors are padded to the local max length per rank; align shapes before gathering. + # Use separate variables so the original completion_ids/completion_mask stay unpadded for the + # teacher concat (they must match the student's sequence length for logits_to_keep alignment). + padded_completion_ids = self.trainer.accelerator.pad_across_processes( + completion_ids, dim=1, pad_index=self.trainer.pad_token_id + ) + all_completion_ids = self.trainer.accelerator.gather(padded_completion_ids) + all_prompts = gather_object(prompts) + total_samples = all_rewards.shape[0] + all_feedbacks = gather_object(feedbacks) if feedbacks is not None else [None] * total_samples + + threshold = self.trainer.args.success_reward_threshold + dont_reprompt_self = self.trainer.args.dont_reprompt_on_self_success + feedback_only_without_solution = self.trainer.args.environment_feedback_only_without_solution + self_distillation_mask = torch.zeros(total_samples, device=device) + num_with_solution = 0 + num_with_feedback_available = 0 + num_with_feedback_used = 0 + success_group_count = 0 + successful_demo_indices: list[int | None] = [None] * total_samples + use_feedback_flags: list[bool] = [False] * total_samples + has_solution_flags: list[bool] = [False] * total_samples + + for i in range(total_samples): + group_start = (i // num_generations) * num_generations + group_end = group_start + num_generations + + successful = [] + if self.trainer.args.use_successful_as_teacher: + for j in range(group_start, group_end): + if dont_reprompt_self and j == i: + continue + if all_rewards[j].item() >= threshold: + successful.append(j) + + if i % num_generations == 0: + # Count groups with any successful rollout, ignoring self-exclusion which only + # affects per-sample teacher assignment, not whether the group has successes. + group_has_success = any(all_rewards[j].item() >= threshold for j in range(group_start, group_end)) + if group_has_success: + success_group_count += 1 + + raw_feedback = all_feedbacks[i] + has_feedback = isinstance(raw_feedback, str) and raw_feedback.strip() != "" + if has_feedback: + num_with_feedback_available += 1 + + has_solution = len(successful) > 0 + has_solution_flags[i] = has_solution + if has_solution: + successful_demo_indices[i] = successful[0] + use_feedback = ( + self.trainer.args.include_environment_feedback + and has_feedback + and (not feedback_only_without_solution or not has_solution) + ) + use_feedback_flags[i] = use_feedback + if use_feedback: + num_with_feedback_used += 1 + if has_solution or use_feedback: + self_distillation_mask[i] = 1.0 + if has_solution: + num_with_solution += 1 + + local_teacher_messages = [] + local_self_distillation_mask = self_distillation_mask[process_slice] + for global_idx in range(process_start, process_start + num_local): + original_prompt = all_prompts[global_idx] + raw_feedback = all_feedbacks[global_idx] + has_solution = has_solution_flags[global_idx] + use_feedback = use_feedback_flags[global_idx] + + if not has_solution and not use_feedback: + local_teacher_messages.append(original_prompt) + continue + + solution_text = "" + if has_solution: + demo_idx = successful_demo_indices[global_idx] + if demo_idx is None: + raise RuntimeError("Expected a successful demonstration index for an active SDPO teacher prompt.") + demo_ids = all_completion_ids[demo_idx] + demo_ids = demo_ids[demo_ids != self.trainer.processing_class.pad_token_id] + demo_text = self.trainer.processing_class.decode(demo_ids, skip_special_tokens=True) + + if self.trainer.args.remove_thinking_from_demonstration: + demo_text = re.sub(r".*?", "", demo_text, flags=re.DOTALL).strip() + + solution_text = self.trainer.args.solution_template.format(successful_previous_attempt=demo_text) + + feedback_text = "" + if use_feedback: + feedback_text = self.trainer.args.feedback_template.format(feedback_raw=raw_feedback) + + if isinstance(original_prompt, list): + system_messages = original_prompt[:-1] + prompt_text = extract_last_user_text(original_prompt) + reprompt_text = self._build_reprompt_text(prompt_text, solution_text, feedback_text) + local_teacher_messages.append(system_messages + [{"role": "user", "content": reprompt_text}]) + else: + local_teacher_messages.append(self._build_reprompt_text(original_prompt, solution_text, feedback_text)) + + teacher_batch = self._tokenize_teacher_messages(local_teacher_messages) + teacher_input_ids = torch.cat([teacher_batch.prompt_ids, completion_ids], dim=1) + teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, completion_mask], dim=1) + + batch_size = total_samples if total_samples > 0 else 1 + num_groups = max(1, total_samples // max(1, num_generations)) + self.last_metrics = { + "self_distillation/success_group_fraction": success_group_count / num_groups, + "self_distillation/success_sample_fraction": num_with_solution / batch_size, + "self_distillation/feedback_available_fraction": num_with_feedback_available / batch_size, + "self_distillation/feedback_used_fraction": num_with_feedback_used / batch_size, + "self_distillation/reprompt_sample_fraction": self_distillation_mask.float().mean().item(), + } + + return { + "teacher_input_ids": teacher_input_ids, + "teacher_attention_mask": teacher_attention_mask, + "self_distillation_mask": local_self_distillation_mask, + } + + +class SDPOTrainer(BaseSelfDistillationTrainer): + """ + Trainer for Self-Distillation Policy Optimization (SDPO). + + SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories. It + converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. + SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed + next-token predictions back into the policy. + """ + + config_cls = SDPOConfig + _tag_names = ["trl", "sdpo"] + _name = "SDPO" + # docstyle-ignore + _paper = { + "title": "Reinforcement Learning via Self-Distillation", + "id": "2601.20802", + "citation": textwrap.dedent("""\ + @article{hubotter2026sdpo, + title = {{Reinforcement Learning via Self-Distillation}}, + author = {Jonas H\\"ubotter and Frederike L\\"ubeck and Lejs Behric and Anton Baumann and Marco Bagatella and Daniel Marta and Ido Hakimi and Idan Shenfeld and Thomas Kleine Buening and Carlos Guestrin and Andreas Krause}, + year = 2026, + eprint = {arXiv:2601.20802} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + reward_funcs: Any | list[Any] | None = None, + args: SDPOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config=None, + ): + if reward_funcs is None or (isinstance(reward_funcs, list) and len(reward_funcs) == 0): + raise ValueError("`reward_funcs` is required for SDPOTrainer because SDPO must score rollouts.") + super().__init__( + model=model, + reward_funcs=reward_funcs, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + self.teacher_context_builder = SuccessfulRolloutTeacherContextBuilder(self) + if self.args.teacher_regularization == "ema": + # `self.model` may already be accelerator-wrapped after the shared base constructor. Build the EMA + # teacher from the unwrapped student model first, then prepare it as an auxiliary eval-only module. + student_model = self.accelerator.unwrap_model(self.model) + self.teacher_model = copy.deepcopy(student_model) + self.teacher_model.requires_grad_(False) + self.teacher_model.eval() + self.teacher_model = self._prepare_auxiliary_model_for_eval(self.teacher_model) + self.add_callback( + EMATeacherSyncCallback( + teacher_model=self.teacher_model, + update_rate=self.args.teacher_update_rate, + accelerator=self.accelerator, + ) + ) + + def _allow_topk_without_full_logit_distillation(self) -> bool: + return False + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + prompts, privileged_contexts = self._split_prompt_and_privileged_context(inputs) + + output = super()._generate_and_score_completions(inputs) + output.update( + self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=privileged_contexts) + ) + + mode = "train" if self.model.training else "eval" + for key, value in self.teacher_context_builder.last_metrics.items(): + self._metrics[mode][key].append(value) + self._warn_on_inactive_self_distillation(mode) + + self._dispatch_self_distillation_callback( + "on_teacher_context_built", + teacher_input_ids=output["teacher_input_ids"], + teacher_attention_mask=output["teacher_attention_mask"], + completion_mask=output["completion_mask"], + self_distillation_mask=output["self_distillation_mask"], + ) + + return output + + def _warn_on_inactive_self_distillation(self, mode: str) -> None: + metrics = self.teacher_context_builder.last_metrics + tolerance = self.args.diagnostics_flat_tolerance + + reprompt_fraction = metrics.get("self_distillation/reprompt_sample_fraction", 0.0) + success_fraction = metrics.get("self_distillation/success_group_fraction", 0.0) + + if reprompt_fraction <= tolerance: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="inactive_self_distillation", + message=( + "SDPO self-distillation is inactive because no reprompted samples were constructed. " + "This usually means no rollout exceeded `success_reward_threshold` and no usable privileged " + "feedback was available." + ), + ) + else: + self._diagnostic_counters[mode]["inactive_self_distillation"] = 0 + + if success_fraction <= tolerance: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="no_successful_rollouts", + message=( + "SDPO did not find any successful rollouts in the current generation groups. " + "If this persists, reduce task difficulty, adjust reward shaping, or lower " + "`success_reward_threshold`." + ), + ) + else: + self._diagnostic_counters[mode]["no_successful_rollouts"] = 0 + + def _compute_loss( + self, + model, + inputs, + ) -> torch.Tensor: + accumulation_scale = self.current_gradient_accumulation_steps if self.model.training else 1.0 + + if self.args.sdpo_policy_loss_mode == "hybrid": + base_policy_loss = super()._compute_loss(model, inputs) + if self.args.distillation_weight <= 0.0: + return base_policy_loss + + sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale + return base_policy_loss + self.args.distillation_weight * sdpo_loss + + if self.args.distillation_weight <= 0.0: + return super()._compute_loss(model, inputs) + + sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale + return self.args.distillation_weight * sdpo_loss diff --git a/trl/experimental/self_distillation/__init__.py b/trl/experimental/self_distillation/__init__.py new file mode 100644 index 00000000000..1449db2f7a3 --- /dev/null +++ b/trl/experimental/self_distillation/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .self_distillation_config import SelfDistillationConfig +from .self_distillation_mixin import SelfDistillationMixin + + +__all__ = ["SelfDistillationConfig", "SelfDistillationMixin"] diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py new file mode 100644 index 00000000000..a5ff5a70849 --- /dev/null +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -0,0 +1,324 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared online self-distillation trainer scaffold. + +This base combines the generic Trainer setup for self-distillation with the online rollout utilities used by SDPO-like +methods. Offline methods such as SDFT stay on `_BaseTrainer` directly and only reuse the shared distillation mixin. +""" + +from __future__ import annotations + +import inspect +from collections import defaultdict +from functools import partial +from typing import Any + +import datasets +import torch +from accelerate.logging import get_logger +from datasets import Dataset, IterableDataset +from torch import nn +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_peft_available + +from ...models import prepare_deepspeed, prepare_fsdp +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + get_config_model_id, + identity, + split_tensor_dict, +) +from ..utils import prepare_peft_model +from .online_rollout_mixin import OnlineRolloutMixin +from .self_distillation_config import SelfDistillationConfig +from .self_distillation_mixin import SelfDistillationMixin + + +if is_peft_available(): + from peft import PeftConfig + + +logger = get_logger(__name__) + + +class BaseSelfDistillationTrainer(OnlineRolloutMixin, SelfDistillationMixin, _BaseTrainer): + """Shared scaffold for experimental self-distillation trainers without GRPO inheritance.""" + + def __init__( + self, + model: str | PreTrainedModel | nn.Module, + reward_funcs: Any | list[Any] | None = None, + args: SelfDistillationConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: PeftConfig | None = None, + ): + if train_dataset is None: + raise ValueError("`train_dataset` is required") + if args.use_vllm: + raise NotImplementedError("Self-distillation trainers do not support `use_vllm=True` yet.") + + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + elif args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the self-distillation config, but `model` is already " + "instantiated. The `model_init_kwargs` will be ignored." + ) + + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): + model = prepare_peft_model(model, peft_config, args) + + if processing_class is None: + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) + + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + self.temperature = args.temperature + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.num_generations_eval = args.num_generations_eval or args.num_generations + self.num_iterations = args.num_iterations + self.shuffle_dataset = args.shuffle_dataset + self.loss_type = args.loss_type + self.importance_sampling_level = args.importance_sampling_level + self.scale_rewards = args.scale_rewards + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high + self.beta = args.beta + self.mask_truncated_completions = args.mask_truncated_completions + self.chat_template_kwargs = args.chat_template_kwargs or {} + self._step = 0 + self._buffered_inputs = None + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._diagnostic_counters = { + "train": defaultdict(int), + "eval": defaultdict(int), + } + + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "repetition_penalty": args.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs) + + if hasattr(model, "warnings_issued"): + model.warnings_issued["estimate_tokens"] = True + + super().__init__( + model=model, + args=args, + data_collator=identity, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + compute_loss_func="non-None value to disable scaling", + ) + + if reward_funcs is None: + reward_funcs = [] + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_model_init_kwargs = args.model_init_kwargs or {} + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + reward_model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, + num_labels=1, + **reward_model_init_kwargs, + ) + if isinstance(reward_funcs[i], nn.Module): + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + if args.reward_weights is not None: + if len(args.reward_weights) != len(self.reward_funcs): + raise ValueError("Number of reward weights must match number of reward functions") + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + + if reward_processing_classes is None: + reward_processing_classes = [None] * len(self.reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(self.reward_funcs): + raise ValueError("Number of reward processing classes must match number of reward functions") + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, self.reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + if args.disable_dropout: + disable_dropout_in_model(self.model) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, nn.Module): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + elif self.is_fsdp_enabled: + self.reward_funcs[i] = prepare_fsdp(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + self.model_accepts_loss_kwargs = False + self.ref_model = None + self.teacher_model = None + if args.sync_ref_model: + raise ValueError( + "sync_ref_model is not supported on the shared online self-distillation base without `ref_model`." + ) + + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset=None) -> Sampler: + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=getattr(self, "num_generations_eval", self.num_generations), + seed=self.args.seed, + ) + + def training_step(self, model, inputs, num_items_in_batch): + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + return output + + def _prepare_inputs(self, generation_batch): + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + generation_batch = self._build_buffered_batch(generation_batch) + self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) + self._dispatch_self_distillation_callback( + "on_generation_batch_built", + generate_every=generate_every, + steps_per_generation=self.args.steps_per_generation, + ) + return self._buffered_inputs[self._step % self.args.steps_per_generation] + return self._build_buffered_batch(generation_batch) + + def _prepare_auxiliary_model_for_eval(self, aux_model: nn.Module): + if self.is_deepspeed_enabled: + return prepare_deepspeed(aux_model, self.accelerator) + if self.is_fsdp_enabled: + return prepare_fsdp(aux_model, self.accelerator) + return self.accelerator.prepare_model(aux_model, evaluation_mode=True) diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py new file mode 100644 index 00000000000..756f66072b9 --- /dev/null +++ b/trl/experimental/self_distillation/online_rollout_mixin.py @@ -0,0 +1,354 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Online rollout helpers for experimental self-distillation trainers. + +This mixin owns generation, reward scoring, grouped reward normalization, and online policy-loss plumbing. It is paired +with `BaseSelfDistillationTrainer` for SDPO-style methods and intentionally kept separate from the generic distillation +loss logic in `self_distillation_mixin.py`. +""" + +from __future__ import annotations + +import torch +from torch import nn +from transformers.utils import logging + +from ...data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ...models import unwrap_model_for_generation +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import pad + + +logger = logging.get_logger(__name__) + + +class OnlineRolloutMixin: + """Online rollout, reward, and policy-loss utilities shared by SDPO-like trainers.""" + + def _apply_prompt_template(self, prompts): + return [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] + for prompt in prompts + ] + + def _build_buffered_batch(self, generation_batch): + return self._generate_and_score_completions(generation_batch) + + def _generate(self, prompts): + # Keep the generation path aligned with the reference trainers: generate from left-padded prompts, + # then recover completion token spans by trimming prompt tokens and stopping at the first EOS. + prompts_text = self._apply_prompt_template(prompts) + generate_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + # This path already receives tokenized model inputs. Bypass the buffered trainer hook and use the plain + # tensor/device preparation from `_BaseTrainer`. + generate_inputs = _BaseTrainer._prepare_inputs(self, generate_inputs) + with ( + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model, + torch.no_grad(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, + generation_config=self.generation_config, + disable_compile=True, + ) + prompt_ids = generate_inputs["input_ids"] + prompt_mask = generate_inputs["attention_mask"] + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) + completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).int() + prompt_ids_list = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=False)] + completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=False)] + return prompt_ids_list, completion_ids_list + + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + if len(self.reward_funcs) == 0: + return torch.zeros((len(prompts), 0), device=device) + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, strict=True) + ): + if isinstance(reward_func, nn.Module): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, + ) + # Reward functions operate on tokenized tensors too, so they need the base Trainer input preparation + # rather than the outer buffered generation hook. + reward_inputs = _BaseTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + else: + output_reward_func = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + return self.accelerator.gather(rewards_per_func) + + def _generate_and_score_completions(self, inputs): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + prompts = [x["prompt"] for x in inputs] + prompt_ids_list, completion_ids_list = self._generate(prompts) + + prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) + completion_ids = [torch.tensor(ids) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) + + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + + with torch.no_grad(): + generate_every = self.args.steps_per_generation * self.num_iterations + if self.args.gradient_accumulation_steps % generate_every != 0: + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + else: + old_per_token_logps = None + + if is_conversational({"prompt": prompts[0]}): + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in completions_text] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + if rewards_per_func.numel() == 0: + rewards = torch.zeros(self.accelerator.num_processes * len(prompts), device=device) + else: + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) + if self.scale_rewards == "batch": + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + elif self.scale_rewards == "none": + std_rewards = torch.ones_like(rewards) + group_std_rewards = torch.ones(rewards.numel() // num_generations, device=device, dtype=rewards.dtype) + else: + group_std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = group_std_rewards.repeat_interleave(num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4) + self._record_reward_diagnostics(mode, rewards, rewards_per_func, group_std_rewards) + + local_batch_size = completion_ids.size(0) + process_start = self.accelerator.process_index * local_batch_size + process_slice = slice(process_start, process_start + local_batch_size) + rewards = rewards[process_slice] + advantages = advantages[process_slice] + + agg_completion_lengths = self.accelerator.gather( + torch.tensor([len(ids) for ids in completion_ids_list], device=device) + ) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "rewards": rewards, + "advantages": advantages, + "num_items_in_batch": completion_mask.sum().detach(), + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + + self._dispatch_self_distillation_callback( + "on_self_distillation_batch_prepared", + old_per_token_logps=old_per_token_logps, + prompt_ids=prompt_ids, + completion_ids=completion_ids, + ) + return output + + def _record_reward_diagnostics( + self, + mode: str, + rewards: torch.Tensor, + rewards_per_func: torch.Tensor, + group_std_rewards: torch.Tensor, + ) -> None: + tolerance = self.args.diagnostics_flat_tolerance + + reward_mean = rewards.mean() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_std = rewards.std() if rewards.numel() > 1 else torch.tensor(0.0, device=self.accelerator.device) + reward_min = rewards.min() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + reward_max = rewards.max() if rewards.numel() > 0 else torch.tensor(0.0, device=self.accelerator.device) + flat_group_fraction = ( + (group_std_rewards <= tolerance).float().mean() + if group_std_rewards.numel() > 0 + else torch.tensor(1.0, device=self.accelerator.device) + ) + + self._metrics[mode]["self_distillation/reward_mean"].append(self.accelerator.gather(reward_mean).mean().item()) + self._metrics[mode]["self_distillation/reward_std"].append(self.accelerator.gather(reward_std).mean().item()) + self._metrics[mode]["self_distillation/reward_min"].append(self.accelerator.gather(reward_min).min().item()) + self._metrics[mode]["self_distillation/reward_max"].append(self.accelerator.gather(reward_max).max().item()) + self._metrics[mode]["self_distillation/group_reward_std_mean"].append( + self.accelerator.gather(group_std_rewards.mean() if group_std_rewards.numel() > 0 else reward_std) + .mean() + .item() + ) + self._metrics[mode]["self_distillation/flat_group_fraction"].append( + self.accelerator.gather(flat_group_fraction).mean().item() + ) + + if rewards_per_func.numel() > 0: + reward_func_means = rewards_per_func.nanmean(dim=0) + gathered_means = self.accelerator.gather(reward_func_means).view(-1, reward_func_means.numel()).mean(dim=0) + for reward_name, reward_func_mean in zip(self.reward_func_names, gathered_means.tolist(), strict=True): + self._metrics[mode][f"self_distillation/rewards/{reward_name}"].append(reward_func_mean) + + reward_is_flat = reward_std.item() <= tolerance + grouped_rewards_are_flat = flat_group_fraction.item() >= 1.0 - tolerance + if reward_is_flat and grouped_rewards_are_flat: + self._warn_on_degenerate_diagnostics( + mode=mode, + counter_key="flat_rewards", + message=( + "Observed flat SDPO rewards across all sampled generations. " + "Policy advantages will collapse to zero, and SDPO will not learn. " + "Check reward density, reward shaping, or `success_reward_threshold`." + ), + ) + else: + self._diagnostic_counters[mode]["flat_rewards"] = 0 + + def _warn_on_degenerate_diagnostics(self, mode: str, counter_key: str, message: str) -> None: + interval = self.args.diagnostics_warning_interval + if interval == 0: + return + + self._diagnostic_counters[mode][counter_key] += 1 + count = self._diagnostic_counters[mode][counter_key] + if count == 1 or count % interval == 0: + logger.warning("%s Consecutive degenerate steps: %s.", message, count) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError(f"The {self.__class__.__name__} does not support returning outputs") + return self._compute_loss(model, inputs) + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + if not isinstance(inputs, dict): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + return loss.detach(), None, None + + def _compute_loss(self, model, inputs): + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + per_token_logps, _ = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "sequence": + log_ratio = (log_ratio * completion_mask).sum(-1, keepdim=True) / completion_mask.sum( + -1, keepdim=True + ).clamp(min=1.0) + coef_1 = torch.exp(log_ratio) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages) + + loss = self._aggregate_self_distillation_loss(per_token_loss, completion_mask) + + mode = "train" if self.model.training else "eval" + self._metrics[mode]["self_distillation/policy_loss"].append( + self.accelerator.gather(loss.detach()).mean().item() + ) + + accumulation_scale = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + return loss / accumulation_scale diff --git a/trl/experimental/self_distillation/peft_adapter_ema_callback.py b/trl/experimental/self_distillation/peft_adapter_ema_callback.py new file mode 100644 index 00000000000..e252bb512a4 --- /dev/null +++ b/trl/experimental/self_distillation/peft_adapter_ema_callback.py @@ -0,0 +1,145 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + + +logger = logging.getLogger(__name__) + + +class PEFTAdapterEMACallback(TrainerCallback): + """ + Callback that maintains an EMA copy of PEFT adapter weights for use as a teacher model in self-distillation. + + The callback creates a secondary adapter ("teacher") with zero-initialized weights and maintains shadow weights + that are updated via exponential moving average: `teacher_weight = (1-α) * teacher_weight + α * student_weight` + + Usage: + ```python + trainer.add_callback( + PEFTAdapterEMACallback( + model=model, + teacher_adapter_name="teacher", + update_rate=0.05, + ) + ) + ``` + """ + + def __init__( + self, + model, + teacher_adapter_name: str = "teacher", + update_rate: float = 0.05, + sync_steps: int = 1, + accelerator=None, + ): + self.model = model + self.teacher_adapter_name = teacher_adapter_name + self.update_rate = update_rate + self.sync_steps = sync_steps + self.accelerator = accelerator + self.shadow_weights: dict[str, torch.Tensor] | None = None + self.teacher_adapter_config = None + self._initialized = False + + def _get_student_state_dict(self): + """Get student adapter state dict using PEFT keys (without adapter name).""" + from peft import get_peft_model_state_dict + + if self.accelerator is not None: + model = self.accelerator.unwrap_model(self.model) + else: + model = self.model + return get_peft_model_state_dict(model) + + def _initialize_teacher_adapter(self): + """Create teacher adapter with zero weights initialized from student adapter.""" + from peft import get_peft_model_state_dict, set_peft_model_state_dict + + if self._initialized: + return + + if self.accelerator is not None: + model = self.accelerator.unwrap_model(self.model) + else: + model = self.model + + adapter_name = model.active_adapter + if adapter_name is None: + adapter_name = "default" + + self.teacher_adapter_config = model.peft_config.get(adapter_name) + + student_state = get_peft_model_state_dict(model) + + teacher_state = {k: torch.zeros_like(v) for k, v in student_state.items()} + + model.add_adapter(self.teacher_adapter_name, self.teacher_adapter_config) + + model.set_adapter(self.teacher_adapter_name) + set_peft_model_state_dict(model, teacher_state, adapter_name=self.teacher_adapter_name) + + model.set_adapter(adapter_name) + + self.shadow_weights = {k: v.clone().zero_() for k, v in teacher_state.items()} + + self._initialized = True + logger.info(f"Initialized PEFT adapter EMA teacher with adapter name: {self.teacher_adapter_name}") + + @torch.no_grad() + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if state.global_step % self.sync_steps != 0: + return + + if not self._initialized: + self._initialize_teacher_adapter() + + if self.shadow_weights is None: + return + + if self.accelerator is None and "accelerator" in kwargs: + self.accelerator = kwargs["accelerator"] + + student_state = self._get_student_state_dict() + + for key, student_param in student_state.items(): + if key in self.shadow_weights: + shadow = self.shadow_weights[key] + shadow.data = (1 - self.update_rate) * shadow.data + self.update_rate * student_param.data + + from peft import set_peft_model_state_dict + + if self.accelerator is not None: + unwrapped_model = self.accelerator.unwrap_model(self.model) + else: + unwrapped_model = self.model + + original_adapter = unwrapped_model.active_adapter + unwrapped_model.set_adapter(self.teacher_adapter_name) + set_peft_model_state_dict(unwrapped_model, self.shadow_weights, adapter_name=self.teacher_adapter_name) + unwrapped_model.set_adapter(original_adapter) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.accelerator is None and "accelerator" in kwargs: + self.accelerator = kwargs["accelerator"] + self._initialize_teacher_adapter() diff --git a/trl/experimental/self_distillation/self_distillation_config.py b/trl/experimental/self_distillation/self_distillation_config.py new file mode 100644 index 00000000000..b0e9cf792f4 --- /dev/null +++ b/trl/experimental/self_distillation/self_distillation_config.py @@ -0,0 +1,308 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + +from ...trainer.base_config import _BaseConfig + + +@dataclass +class SelfDistillationConfig(_BaseConfig): + r""" + Shared configuration for experimental self-distillation trainers. + + This class contains only the arguments that are specific to the shared self-distillation stack. For the full set of + generic training arguments, refer to [`~transformers.TrainingArguments`] via + [`trl.trainer.base_config._BaseConfig`]. + + Parameters: + > Parameters that control generation and rollout reuse + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments used when the `model` argument is passed as a string. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum prompt length. Longer prompts are truncated from the left. + num_generations (`int`, *optional*, defaults to `8`): + Number of sampled generations per prompt. + generation_batch_size (`int` or `None`, *optional*): + Global batch size used for generation. Mutually exclusive with `steps_per_generation`. + steps_per_generation (`int` or `None`, *optional*): + Number of optimizer steps that reuse one generated batch. Mutually exclusive with `generation_batch_size`. + + > Parameters that control the online policy objective + + beta (`float`, *optional*, defaults to `0.0`): + Reference-model KL coefficient for online policy optimization. + loss_type (`str`, *optional*, defaults to `"dapo"`): + Policy-loss aggregation mode. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`. + scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): + Reward normalization mode. Supported: `group`, `batch`, `none`. + + > Parameters that control self-distillation + + distillation_alpha (`float`, *optional*, defaults to `0.5`): + Divergence interpolation coefficient using the official SDPO/SDFT convention: `0.0=forward KL`, `0.5=JSD`, + `1.0=reverse KL`. + distillation_topk (`int` or `None`, *optional*, defaults to `100`): + Number of top tokens to keep for top-k distillation. If `None`, all logits are used. + full_logit_distillation (`bool`, *optional*, defaults to `False`): + Whether to use full-logit distillation instead of token-level distillation. + distillation_is_clip (`float` or `None`, *optional*, defaults to `2.0`): + Importance-sampling clip used by the official SDPO-style correction. `None` disables clipping. + distillation_weight (`float`, *optional*, defaults to `1.0`): + Weight applied to the self-distillation loss term. + + > Parameters that control diagnostics + + diagnostics_warning_interval (`int`, *optional*, defaults to `10`): + Emit repeated trainer diagnostics every N consecutive degenerate steps. Set to `0` to disable. + diagnostics_flat_tolerance (`float`, *optional*, defaults to `1e-8`): + Tolerance used to decide whether reward variance or reprompt activity is effectively zero. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={"help": "Keyword arguments for model initialization when `model` is passed as a string."}, + ) + disable_dropout: bool = field( + default=False, + metadata={"help": "Whether to disable dropout in the student model."}, + ) + remove_unused_columns: bool = field( + default=False, + metadata={"help": "Whether to drop dataset columns unused by the trainer."}, + ) + max_prompt_length: int | None = field( + default=512, + metadata={"help": "Maximum prompt length. Longer prompts are truncated from the left."}, + ) + num_generations: int = field( + default=8, + metadata={"help": "Number of sampled generations per prompt."}, + ) + num_generations_eval: int | None = field( + default=None, + metadata={"help": "Number of sampled generations per prompt during evaluation."}, + ) + max_completion_length: int | None = field( + default=256, + metadata={"help": "Maximum generated completion length."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={"help": "Whether to gather ZeRO-3 weights for generation."}, + ) + shuffle_dataset: bool = field( + default=True, + metadata={"help": "Whether to shuffle the training dataset."}, + ) + generation_batch_size: int | None = field( + default=None, + metadata={"help": "Global batch size used for generation. Mutually exclusive with `steps_per_generation`."}, + ) + steps_per_generation: int | None = field( + default=None, + metadata={"help": "Number of optimizer steps that reuse one generated batch."}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "Sampling temperature."}, + ) + top_p: float = field( + default=1.0, + metadata={"help": "Top-p sampling parameter."}, + ) + top_k: int = field( + default=0, + metadata={"help": "Top-k sampling parameter. `0` disables top-k filtering."}, + ) + min_p: float | None = field( + default=None, + metadata={"help": "Minimum token probability for sampling."}, + ) + generation_kwargs: dict[str, Any] | None = field( + default=None, + metadata={"help": "Extra generation kwargs passed to `GenerationConfig`."}, + ) + chat_template_kwargs: dict[str, Any] | None = field( + default=None, + metadata={"help": "Extra kwargs forwarded to chat template application."}, + ) + repetition_penalty: float = field( + default=1.0, + metadata={"help": "Repetition penalty used during generation."}, + ) + use_transformers_paged: bool = field( + default=False, + metadata={"help": "Reserved for paged generation support."}, + ) + cache_implementation: str | None = field( + default=None, + metadata={"help": "Cache implementation used by transformers generation."}, + ) + use_vllm: bool = field( + default=False, + metadata={"help": "Whether to use vLLM for generation."}, + ) + beta: float = field( + default=0.0, + metadata={"help": "Reference-model KL coefficient for online policy optimization."}, + ) + num_iterations: int = field( + default=1, + metadata={"help": "Number of optimization iterations per generated batch."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Lower clipping coefficient for GRPO-style policy loss."}, + ) + epsilon_high: float | None = field( + default=None, + metadata={"help": "Upper clipping coefficient. Defaults to `epsilon` when unset."}, + ) + importance_sampling_level: str = field( + default="token", + metadata={"help": "Importance-sampling granularity. Supported: `token`, `sequence`."}, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={"help": "Optional weights for multiple reward functions."}, + ) + scale_rewards: str | bool = field( + default="group", + metadata={"help": "Reward normalization mode. Supported: `group`, `batch`, `none`."}, + ) + loss_type: str = field( + default="dapo", + metadata={"help": "Policy loss aggregation. Supported: `grpo`, `bnpo`, `dr_grpo`, `dapo`."}, + ) + mask_truncated_completions: bool = field( + default=False, + metadata={"help": "Whether to exclude truncated completions from the loss."}, + ) + sync_ref_model: bool = field( + default=False, + metadata={"help": "Whether to synchronize the reference model with the student model."}, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={"help": "EMA mix coefficient used when syncing the reference model."}, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={"help": "How often to synchronize the reference model."}, + ) + top_entropy_quantile: float = field( + default=1.0, + metadata={"help": "Reserved for entropy-based token filtering."}, + ) + distillation_alpha: float = field( + default=0.5, + metadata={"help": "KL divergence direction: 0.0=forward KL, 0.5=JSD, 1.0=reverse KL."}, + ) + distillation_topk: int | None = field( + default=100, + metadata={"help": "Number of top tokens for top-k distillation. If None, uses all tokens."}, + ) + full_logit_distillation: bool = field( + default=False, + metadata={"help": "Whether to use full-logit distillation instead of token-level distillation."}, + ) + distillation_is_clip: float | None = field( + default=2.0, + metadata={"help": "Clipping coefficient for importance sampling in self-distillation."}, + ) + distillation_add_tail: bool = field( + default=False, + metadata={"help": "Whether to add a tail bucket for non-top-k probability mass."}, + ) + distillation_weight: float = field( + default=1.0, + metadata={"help": "Weight applied to the self-distillation loss term."}, + ) + diagnostics_warning_interval: int = field( + default=10, + metadata={ + "help": "Emit repeated trainer diagnostics every N consecutive degenerate steps. Set to 0 to disable." + }, + ) + diagnostics_flat_tolerance: float = field( + default=1e-8, + metadata={ + "help": "Tolerance used to decide whether reward variance or reprompt activity is effectively zero." + }, + ) + + def __post_init__(self): + super().__post_init__() + + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + if self.scale_rewards not in ["group", "batch", "none"]: + raise ValueError("scale_rewards must be one of: 'group', 'batch', 'none'") + + if self.importance_sampling_level not in ["token", "sequence"]: + raise ValueError("importance_sampling_level must be either 'token' or 'sequence'") + if self.loss_type not in ["grpo", "bnpo", "dr_grpo", "dapo"]: + raise ValueError("loss_type must be one of: 'grpo', 'bnpo', 'dr_grpo', 'dapo'") + if self.num_generations < 1: + raise ValueError("num_generations must be at least 1") + if not 0.0 <= self.distillation_alpha <= 1.0: + raise ValueError("distillation_alpha must be in [0, 1]") + if self.distillation_topk is not None and self.distillation_topk <= 0: + raise ValueError("distillation_topk must be positive when provided") + if self.distillation_is_clip is not None and self.distillation_is_clip <= 0: + raise ValueError("distillation_is_clip must be positive when provided") + if self.distillation_weight < 0: + raise ValueError("distillation_weight must be non-negative") + if self.diagnostics_warning_interval < 0: + raise ValueError("diagnostics_warning_interval must be non-negative") + if self.diagnostics_flat_tolerance < 0: + raise ValueError("diagnostics_flat_tolerance must be non-negative") + + num_processes = self.world_size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + global_batch_size = self.per_device_train_batch_size * num_processes + if self.generation_batch_size % global_batch_size != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size ({global_batch_size})." + ) + self.steps_per_generation = self.generation_batch_size // global_batch_size + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError("'generation_batch_size' and 'steps_per_generation' can not both be configured") + + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations ({self.num_generations})." + ) + + if self.do_eval and self.eval_strategy != "no": + num_generations_eval = self.num_generations_eval or self.num_generations + if (self.per_device_eval_batch_size * num_processes) % num_generations_eval != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by the number of generations used for evaluation ({num_generations_eval})." + ) + + if self.epsilon_high is None: + self.epsilon_high = self.epsilon diff --git a/trl/experimental/self_distillation/self_distillation_mixin.py b/trl/experimental/self_distillation/self_distillation_mixin.py new file mode 100644 index 00000000000..fb2a8808de1 --- /dev/null +++ b/trl/experimental/self_distillation/self_distillation_mixin.py @@ -0,0 +1,295 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared self-distillation loss utilities used by experimental trainers. + +This module intentionally holds only the reusable distillation mechanics: callback dispatch, common prompt/context +helpers, and the student-vs-teacher loss computation. Trainer lifecycle and online rollout concerns live in the trainer +classes or their online-specific base. +""" + +from __future__ import annotations + +from contextlib import nullcontext +from typing import Any + +import torch +import torch.nn.functional as F + +from ...trainer.utils import entropy_from_logits, selective_log_softmax +from .self_distillation_config import SelfDistillationConfig + + +class SelfDistillationMixin: + """Reusable self-distillation helpers shared across experimental trainers.""" + + config_cls = SelfDistillationConfig + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + self._signature_columns = ["prompt", "privileged_context"] + + def _dispatch_self_distillation_callback(self, event_name: str, **payload) -> None: + for callback in self.callback_handler.callbacks: + callback_fn = getattr(callback, event_name, None) + if callback_fn is not None: + callback_fn( + args=self.args, + state=self.state, + control=self.control, + model=self.model, + processing_class=self.processing_class, + **payload, + ) + + @staticmethod + def _split_prompt_and_privileged_context(inputs: list[dict[str, Any]]) -> tuple[list[Any], list[Any]]: + prompts = [example["prompt"] for example in inputs] + privileged_contexts = [example.get("privileged_context") for example in inputs] + return prompts, privileged_contexts + + def _allow_topk_without_full_logit_distillation(self) -> bool: + return True + + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=False, + ): + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} + if "logits_to_keep" in self.model_kwarg_keys: + model_inputs["logits_to_keep"] = logits_to_keep + 1 + logits = model(**model_inputs).logits + logits = logits[:, :-1, :] + logits = logits[:, -logits_to_keep:, :] + logits = logits / self.temperature + completion_ids = input_ids[:, -logits_to_keep:] + selected_logps = selective_log_softmax(logits, completion_ids) + entropies = entropy_from_logits(logits) if compute_entropy else None + return selected_logps, entropies + + def _compute_self_distillation_loss( + self, + model, + inputs: dict[str, Any], + ) -> torch.Tensor: + # Expected batch contract: + # - required: `prompt_ids`, `prompt_mask`, `completion_ids`, `completion_mask`, + # `teacher_input_ids`, `teacher_attention_mask` + # - optional: `self_distillation_mask` to zero-out samples without teacher supervision, + # `old_per_token_logps` to enable IS clipping when generation and optimization are misaligned + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + logits_to_keep = completion_ids.size(1) + + self_distillation_mask = inputs.get("self_distillation_mask") + if self_distillation_mask is not None: + response_mask = completion_mask * self_distillation_mask.unsqueeze(1) + else: + response_mask = completion_mask + + if response_mask.sum() == 0: + mode = "train" if model.training else "eval" + self._log_self_distillation_metric(mode, "distillation_loss", 0.0) + return torch.tensor(0.0, device=completion_ids.device, requires_grad=True) + + student_input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + student_attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + student_model_inputs = { + "input_ids": student_input_ids, + "attention_mask": student_attention_mask, + "use_cache": False, + } + if "logits_to_keep" in self.model_kwarg_keys: + student_model_inputs["logits_to_keep"] = logits_to_keep + 1 + + student_logits = model(**student_model_inputs).logits + student_logits = student_logits[:, :-1, :] + student_logits = student_logits[:, -logits_to_keep:, :] + student_logits = student_logits / self.temperature + + teacher_input_ids = inputs["teacher_input_ids"] + teacher_attention_mask = inputs["teacher_attention_mask"] + teacher_model_inputs = { + "input_ids": teacher_input_ids, + "attention_mask": teacher_attention_mask, + "use_cache": False, + } + if "logits_to_keep" in self.model_kwarg_keys: + teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1 + + teacher_model = self._get_teacher_model_for_self_distillation(model) + with torch.no_grad(), self._get_teacher_context_for_self_distillation(model): + teacher_logits = teacher_model(**teacher_model_inputs).logits + teacher_logits = teacher_logits[:, :-1, :] + teacher_logits = teacher_logits[:, -logits_to_keep:, :] + teacher_logits = teacher_logits / self.temperature + + use_topk_distillation = self.args.distillation_topk is not None and ( + self.args.full_logit_distillation or self._allow_topk_without_full_logit_distillation() + ) + if use_topk_distillation: + student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) + topk_student_logits, topk_indices = torch.topk(student_logits, k=self.args.distillation_topk, dim=-1) + topk_student_log_probs = topk_student_logits - student_logsumexp + + teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) + topk_teacher_logits = torch.gather(teacher_logits, dim=-1, index=topk_indices) + topk_teacher_log_probs = topk_teacher_logits - teacher_logsumexp + + if self.args.distillation_add_tail: + topk_student_log_probs = self._add_tail(topk_student_log_probs) + topk_teacher_log_probs = self._add_tail(topk_teacher_log_probs) + else: + topk_student_log_probs = self._renorm_topk_log_probs(topk_student_log_probs) + topk_teacher_log_probs = self._renorm_topk_log_probs(topk_teacher_log_probs) + + per_token_loss = self._compute_divergence( + topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha + ) + elif self.args.full_logit_distillation: + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + per_token_loss = self._compute_divergence( + student_log_probs, teacher_log_probs, self.args.distillation_alpha + ) + else: + if self.args.distillation_alpha != 1.0: + raise ValueError( + "Only reverse KL (alpha=1.0) is supported for token-level distillation when " + "`full_logit_distillation=False`, " + f"got alpha={self.args.distillation_alpha}" + ) + student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True) + teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True) + idx = completion_ids.unsqueeze(-1) + student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_logsumexp).squeeze(-1) + teacher_per_token_logps = (torch.gather(teacher_logits, dim=-1, index=idx) - teacher_logsumexp).squeeze(-1) + per_token_loss = self._compute_token_level_distillation_loss( + student_per_token_logps, teacher_per_token_logps + ) + + if self.args.distillation_is_clip is not None: + old_log_probs = inputs.get("old_per_token_logps") + if old_log_probs is not None: + with torch.no_grad(): + student_lse = torch.logsumexp(student_logits, dim=-1, keepdim=True) + idx = completion_ids.unsqueeze(-1) + student_per_token_logps = (torch.gather(student_logits, dim=-1, index=idx) - student_lse).squeeze( + -1 + ) + per_token_loss = self._apply_importance_sampling_clipping( + per_token_loss, student_per_token_logps, old_log_probs, self.args.distillation_is_clip + ) + + loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask) + + mode = "train" if model.training else "eval" + mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + self._log_self_distillation_metric( + mode, + "distillation_loss", + self.accelerator.gather(mean_distill_loss).mean().item(), + ) + + return loss + + def _get_teacher_model_for_self_distillation(self, model): + teacher_model = getattr(self, "teacher_model", None) + if teacher_model is None: + return model + return teacher_model + + def _get_teacher_context_for_self_distillation(self, model): + return nullcontext() + + def _log_self_distillation_metric(self, mode: str, metric_name: str, value: float) -> None: + metric_prefix = getattr(self, "_name", "self_distillation").lower().replace(" ", "_") + self._metrics[mode][f"self_distillation/{metric_name}"].append(value) + self._metrics[mode][f"{metric_prefix}/{metric_name}"].append(value) + + @staticmethod + def _compute_divergence( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + alpha: float, + ) -> torch.Tensor: + if alpha == 0.0: + kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif alpha == 1.0: + kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + alpha_t = torch.tensor(alpha, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - alpha_t), teacher_log_probs + torch.log(alpha_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True) + kl = torch.lerp(kl_student, kl_teacher, alpha) + return kl.sum(-1) + + @staticmethod + def _add_tail(log_probs: torch.Tensor) -> torch.Tensor: + log_s = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_s = torch.clamp(log_s, max=-1e-7) + tail_log = torch.log(-torch.expm1(log_s)) + return torch.cat([log_probs, tail_log], dim=-1) + + @staticmethod + def _renorm_topk_log_probs(log_probs: torch.Tensor) -> torch.Tensor: + return log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True) + + @staticmethod + def _compute_token_level_distillation_loss( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + ) -> torch.Tensor: + # This is the token-level reverse-KL surrogate used by the official SDPO implementation for + # `full_logit_distillation=False`. It intentionally treats the teacher log-probs as fixed targets + # and keeps only the score-function term for the sampled student tokens. + log_ratio = student_log_probs - teacher_log_probs + return log_ratio.detach() * student_log_probs + + @staticmethod + def _apply_importance_sampling_clipping( + per_token_loss: torch.Tensor, + student_log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + clip_coeff: float, + ) -> torch.Tensor: + negative_approx_kl = (student_log_probs - old_log_probs).detach() + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl).clamp(max=clip_coeff) + return per_token_loss * ratio + + def _aggregate_self_distillation_loss( + self, + per_token_loss: torch.Tensor, + response_mask: torch.Tensor, + ) -> torch.Tensor: + loss_type = self.loss_type + if loss_type == "grpo": + loss = (per_token_loss * response_mask).sum(-1) / response_mask.sum(-1).clamp(min=1.0) + return loss.mean() + if loss_type == "bnpo": + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + if loss_type == "dr_grpo": + return (per_token_loss * response_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + if loss_type in ["dapo", "luspo", "cispo", "sapo"]: + return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0) + raise ValueError(f"Unsupported loss_type for self-distillation: {loss_type}") diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/teacher_context.py new file mode 100644 index 00000000000..5e1020c91a7 --- /dev/null +++ b/trl/experimental/self_distillation/teacher_context.py @@ -0,0 +1,85 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from ...data_utils import maybe_apply_chat_template +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import pad + + +def extract_last_user_text(prompt: list[dict[str, Any]]) -> str: + """Extract the text content from the last user message in a conversational prompt.""" + last_message = prompt[-1] + if last_message.get("role") != "user": + raise ValueError( + f"Self-distillation teacher prompt construction expects the conversation to end with a user turn, " + f"but the last message has role '{last_message.get('role')}'. " + f"Prompts ending with assistant prefills or tool turns are not supported." + ) + content = last_message.get("content", "") + if isinstance(content, list): + return " ".join(part.get("text", "") for part in content if part.get("type") == "text") + return content + + +@dataclass +class TokenizedPromptBatch: + prompt_ids: torch.Tensor + prompt_mask: torch.Tensor + + +class PromptTokenizer: + """Internal helper to tokenize prompt-like inputs consistently across self-distillation trainers.""" + + def __init__(self, trainer): + self.trainer = trainer + + def apply_prompt_template(self, prompts: list[Any]) -> list[str]: + return [ + maybe_apply_chat_template( + {"prompt": prompt}, + self.trainer.processing_class, + **getattr(self.trainer, "chat_template_kwargs", {}), + )["prompt"] + for prompt in prompts + ] + + def tokenize_prompts(self, prompts: list[Any]) -> TokenizedPromptBatch: + prompt_text = self.apply_prompt_template(prompts) + prompt_inputs = self.trainer.processing_class( + text=prompt_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.trainer.max_prompt_length, + truncation=True, + add_special_tokens=False, + ) + prompt_inputs = super(_BaseTrainer, self.trainer)._prepare_inputs(prompt_inputs) + prompt_ids = [ + p[m].tolist() + for p, m in zip(prompt_inputs["input_ids"], prompt_inputs["attention_mask"].bool(), strict=False) + ] + prompt_ids = [torch.tensor(ids, device=self.trainer.accelerator.device) for ids in prompt_ids] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + return TokenizedPromptBatch( + prompt_ids=pad(prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + prompt_mask=pad(prompt_mask, padding_value=0, padding_side="left"), + )