From 0cf0afb15f146b1b9d480776df48b5ace0fad16a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 28 Apr 2026 21:09:52 +0000 Subject: [PATCH 1/6] Enable chunked NLL loss with PEFT in SFT --- tests/test_sft_trainer.py | 94 ++++++++++++++++++++++++++++++++++++++ trl/trainer/sft_trainer.py | 38 ++++++++------- 2 files changed, 115 insertions(+), 17 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index bafbe746b04..4b5c3867c31 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -500,6 +500,43 @@ def test_train_chunked_nll_loss(self): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + @require_peft + def test_train_chunked_nll_loss_peft(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, loss_type="chunked_nll", report_to="none") + + trainer = SFTTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + def test_train_moe_model_with_aux_loss(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") @@ -2628,3 +2665,60 @@ def test_forward_without_labels_matches_reference(self): ref_out = ref_model(input_ids=input_ids) out = chunked_model(input_ids=input_ids) torch.testing.assert_close(out.logits, ref_out.logits, atol=1e-5, rtol=1e-5) + + @require_peft + @pytest.mark.filterwarnings("ignore:Model has `tie_word_embeddings=True`") + @pytest.mark.parametrize( + "peft_config_factory", + [ + pytest.param(lambda: LoraConfig(r=4, target_modules=["q_proj", "v_proj"]), id="lora"), + pytest.param( + lambda: LoraConfig(r=4, target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"]), + id="lora+modules_to_save", + ), + pytest.param( + lambda: PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4), id="prompt_tuning" + ), + pytest.param( + lambda: PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4), id="prompt_encoder" + ), + pytest.param( + lambda: PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4), id="prefix_tuning" + ), + ], + ) + def test_forward_matches_reference_with_peft(self, peft_config_factory): + """Patching the inner causal LM (`peft_model.get_base_model()`) must produce a forward whose loss matches + the unpatched PEFT reference for both LoRA-style (adapters live in the module tree) and prompt-learning + (`PeftModel.forward` injects virtual tokens, then delegates into the patched inner forward).""" + base = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32 + ).to(torch_device) + ref_model = get_peft_model(copy.deepcopy(base), peft_config_factory()) + chunked_model = copy.deepcopy(ref_model) + _patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE) + + B, S = 2, 16 + torch.manual_seed(42) + input_ids = torch.randint(0, base.config.vocab_size, (B, S), device=torch_device) + labels = input_ids.clone() + labels[:, :4] = -100 + num_items = int((labels[..., 1:] != -100).sum()) + + ref_out = ref_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + out = chunked_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + torch.testing.assert_close(out.loss, ref_out.loss, atol=1e-5, rtol=1e-5) + + ref_out.loss.backward() + out.loss.backward() + chunked_params = dict(chunked_model.named_parameters()) + for name, ref_param in ref_model.named_parameters(): + if not ref_param.requires_grad or ref_param.grad is None: + continue + torch.testing.assert_close( + chunked_params[name].grad, + ref_param.grad, + atol=1e-5, + rtol=1e-5, + msg=f"gradient mismatch on {name}", + ) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index f5ae7e16722..3850b64be94 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -212,23 +212,24 @@ def _chunked_cross_entropy_loss( def _patch_chunked_ce_lm_head(model: torch.nn.Module, chunk_size: int) -> None: """ - Patch a causal LM so its `forward` computes the language modeling loss via [`_chunked_cross_entropy_loss`] when - `labels` are provided. - - The patched forward calls the base decoder directly (`model.get_decoder()`) to obtain hidden states, skips the - `lm_head` matmul on positions with `labels == -100`, and computes the cross-entropy in chunks of `chunk_size` valid - tokens. It returns a [`_ChunkedCELMHeadOutput`] with `loss` set, `logits` set to `None`, and `token_accuracy` / - `entropy` fields set to the mean values over non-ignored tokens. Also accepts pre-shifted `shift_labels` in place - of `labels`, for the context / sequence parallelism path. When both are `None`, the original forward is invoked so - generation and labels-free evaluation preserve any per-model logits post-processing (e.g. `logit_scale`, + Patch a causal LM so `forward` computes the language modeling loss via [`_chunked_cross_entropy_loss`]. + + When `labels` (or pre-shifted `shift_labels`, for CP/SP) are provided, the patched forward calls + `model.get_decoder()` directly, drops positions with `labels == -100` before the `lm_head` matmul, and computes the + cross-entropy in chunks of `chunk_size` valid tokens. It returns a [`_ChunkedCELMHeadOutput`] with `loss` set, + `logits` set to `None`, and `num_correct_tokens` / `entropy_sum` populated over non-ignored tokens. For MoE models + with `output_router_logits=True`, the load-balancing auxiliary loss is added with the same coefficient and formula + as the model's own forward, keeping the chunked path numerically equivalent to the reference. + + With both `labels` and `shift_labels` set to `None` the original forward runs unchanged, so generation and + labels-free evaluation preserve any per-model logits post-processing (e.g. `logit_scale`, `final_logit_softcapping`). - For MoE models with `output_router_logits=True`, the load-balancing auxiliary loss is added to the main loss with - the same coefficient (`router_aux_loss_coef`) and formula (`load_balancing_loss_func`) used by the model's own - forward, so the chunked path remains numerically equivalent to the reference. + For PEFT, pass the inner causal LM (`peft_model.get_base_model()`) rather than the `PeftModel` wrapper, so that + prompt-learning variants (PromptTuning, PrefixTuning, PTuning) keep their virtual-token injection in + `PeftModel.forward` before delegating into the patched forward. - Not supported yet: VLM / multimodal models whose forward injects visual tokens outside the base decoder, and - PEFT-wrapped models. + Not supported yet: VLM / multimodal models whose forward injects visual tokens outside the base decoder. """ final_logit_softcapping = getattr(model.config, "final_logit_softcapping", None) logit_scale = getattr(model.config, "logit_scale", 1.0) @@ -1230,9 +1231,12 @@ def __init__( # wraps the patched forward. if self._is_vlm: raise NotImplementedError("`loss_type='chunked_nll'` is not supported for VLM models yet.") - if peft_config is not None or is_peft_model(model): - raise NotImplementedError("`loss_type='chunked_nll'` is not supported with PEFT yet.") - _patch_chunked_ce_lm_head(model, chunk_size=_CHUNKED_LM_HEAD_CHUNK_SIZE) + # For PEFT, patch the inner causal LM rather than the `PeftModel` wrapper. LoRA / IA³ / + # `modules_to_save` adapters live in the module tree, so they're hit even when we bypass + # `PeftModel.forward`. Prompt-learning variants need `PeftModel.forward` to run first (to inject + # virtual tokens), then it delegates into the patched inner forward. + target = model.get_base_model() if is_peft_model(model) else model + _patch_chunked_ce_lm_head(target, chunk_size=_CHUNKED_LM_HEAD_CHUNK_SIZE) else: raise ValueError( f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll', 'dft', and " From b3abe56bb4e5ae30b3b97a6a42a54f7e055dbe30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 29 Apr 2026 00:16:37 +0000 Subject: [PATCH 2/6] fix for prompt tuning --- tests/test_sft_trainer.py | 55 ++++++++++++++++++++++++++++++-------- trl/trainer/sft_trainer.py | 38 +++++++++++++++----------- 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 4b5c3867c31..0ba8d177a5c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -2444,22 +2444,24 @@ def _reference(hidden, weight, labels, num_items_in_batch=None): def test_forward_matches_cross_entropy(self): """With no ignored tokens, chunked loss equals standard mean cross-entropy.""" hidden, weight, labels = self._inputs() - n_valid = (labels[..., 1:] != -100).sum() - loss_c, correct_c, ent_sum_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) + expected_n_valid = (labels[..., 1:] != -100).sum() + loss_c, correct_c, ent_sum_c, n_valid_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) loss_r, acc_r, ent_r = self._reference(hidden, weight, labels) torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(correct_c / n_valid, acc_r, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(ent_sum_c / n_valid, ent_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(correct_c / n_valid_c, acc_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(ent_sum_c / n_valid_c, ent_r, atol=1e-5, rtol=1e-5) + assert n_valid_c.item() == expected_n_valid.item() def test_forward_ignore_index(self): """Ignored labels are excluded from loss, accuracy and entropy (matches F.cross_entropy).""" hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) - n_valid = (labels[..., 1:] != -100).sum() - loss_c, correct_c, ent_sum_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) + expected_n_valid = (labels[..., 1:] != -100).sum() + loss_c, correct_c, ent_sum_c, n_valid_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) loss_r, acc_r, ent_r = self._reference(hidden, weight, labels) torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(correct_c / n_valid, acc_r, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(ent_sum_c / n_valid, ent_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(correct_c / n_valid_c, acc_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(ent_sum_c / n_valid_c, ent_r, atol=1e-5, rtol=1e-5) + assert n_valid_c.item() == expected_n_valid.item() def test_num_items_in_batch_reduction(self): """When num_items_in_batch is provided, loss is sum / num_items_in_batch.""" @@ -2503,12 +2505,13 @@ def test_all_ignored_returns_zero(self): hidden, weight, labels = self._inputs(requires_grad=True) bias = torch.zeros(self.V, dtype=torch.float32, requires_grad=True) labels[:] = -100 - loss, correct, ent_sum = _chunked_cross_entropy_loss( + loss, correct, ent_sum, n_valid = _chunked_cross_entropy_loss( hidden, weight, self.CHUNK_SIZE, labels, lm_head_bias=bias ) assert loss.item() == 0.0 assert correct.item() == 0.0 assert ent_sum.item() == 0.0 + assert n_valid.item() == 0 assert not torch.isnan(loss) # Backward must succeed even when n_valid == 0 (can happen with completion-only loss # + truncation where a whole micro-batch is masked). @@ -2520,15 +2523,16 @@ def test_all_ignored_returns_zero(self): def test_shift_labels_matches_labels(self): """`shift_labels` path (CP/SP) must match the default `labels` path after external shifting.""" hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) - loss_l, correct_l, ent_l = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) + loss_l, correct_l, ent_l, n_valid_l = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) # Mimic what transformers does under CP/SP: pad labels with -100, then shift. shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() - loss_s, correct_s, ent_s = _chunked_cross_entropy_loss( + loss_s, correct_s, ent_s, n_valid_s = _chunked_cross_entropy_loss( hidden, weight, self.CHUNK_SIZE, shift_labels=shift_labels ) torch.testing.assert_close(loss_s, loss_l, atol=1e-6, rtol=1e-6) torch.testing.assert_close(correct_s, correct_l, atol=1e-6, rtol=1e-6) torch.testing.assert_close(ent_s, ent_l, atol=1e-6, rtol=1e-6) + assert n_valid_s.item() == n_valid_l.item() def test_requires_labels_or_shift_labels(self): """Must provide at least one of `labels` or `shift_labels`.""" @@ -2722,3 +2726,32 @@ def test_forward_matches_reference_with_peft(self, peft_config_factory): rtol=1e-5, msg=f"gradient mismatch on {name}", ) + + @require_peft + @pytest.mark.filterwarnings("ignore:Model has `tie_word_embeddings=True`") + def test_num_valid_tokens_with_prompt_learning_peft(self): + """For prompt-learning PEFT (PromptTuning, P-Tuning), `PeftModel.forward` prepends `-100`-padded virtual + tokens before delegating into the patched inner forward. The patched output's `num_valid_tokens` must reflect + the padded labels — when original `label[0] != -100`, it counts as a valid target paired with the last + virtual token's hidden state, so it must be included in the metric denominator to keep accuracy ≤ 1.""" + base = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32 + ).to(torch_device) + peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4) + chunked_model = get_peft_model(base, peft_config) + _patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE) + + B, S = 2, 16 + torch.manual_seed(42) + input_ids = torch.randint(0, base.config.vocab_size, (B, S), device=torch_device) + labels = input_ids.clone() # all positions valid, including label[0] + + out = chunked_model(input_ids=input_ids, labels=labels) + + # `labels[..., 1:]` (un-padded, what compute_loss used to compute) excludes original `label[0]`, + # but the patched forward sees padded labels and counts `label[0]` as a valid target. + unpadded = int((labels[..., 1:] != -100).sum()) + # One extra valid target per sequence (original `label[0]`). + assert out.num_valid_tokens.item() == unpadded + B + # Accuracy denominator from the patched output keeps numerator/denominator aligned, so accuracy ≤ 1. + assert out.num_correct_tokens.item() <= out.num_valid_tokens.item() diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 3850b64be94..ee7ca6901c9 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -82,6 +82,7 @@ class _ChunkedCELMHeadOutput(CausalLMOutputWithPast): num_correct_tokens: torch.Tensor | None = None entropy_sum: torch.Tensor | None = None + num_valid_tokens: torch.Tensor | None = None aux_loss: torch.Tensor | None = None @@ -110,7 +111,7 @@ def _chunked_cross_entropy_loss( logit_scale: float = 1.0, final_logit_softcapping: float | None = None, lm_head_bias: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Memory-efficient next-token cross-entropy over hidden states and an `lm_head` weight. @@ -152,9 +153,9 @@ def _chunked_cross_entropy_loss( Bias of the `lm_head` linear layer, shape `(V,)`. Added to each chunk's logits when provided. Returns: - `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: scalar loss, number of correctly-predicted tokens (count), - and sum of per-token Shannon entropy (in nats) — all over the local batch. Raw sums are returned so callers can - reduce correctly across ranks. + `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]`: scalar loss, number of correctly-predicted + tokens (count), sum of per-token Shannon entropy (in nats), and number of valid (non-`-100`) target tokens — + all over the local batch. Raw sums are returned so callers can reduce correctly across ranks. """ if labels is None and shift_labels is None: raise ValueError("At least one of `labels` or `shift_labels` must be provided.") @@ -173,6 +174,7 @@ def _chunked_cross_entropy_loss( correct = hidden.new_zeros((), dtype=torch.float32) entropy_sum = hidden.new_zeros((), dtype=torch.float32) + n_valid_tensor = torch.tensor(n_valid, device=hidden.device, dtype=torch.long) if n_valid == 0: # Whole micro-batch masked (e.g. completion-only loss + truncation). Keep the loss connected # to the autograd graph through every trainable parameter so `.backward()` succeeds and DDP / @@ -180,7 +182,7 @@ def _chunked_cross_entropy_loss( loss = (hidden_states.float().sum() + lm_head_weight.float().sum()) * 0.0 if lm_head_bias is not None: loss = loss + lm_head_bias.float().sum() * 0.0 - return loss, correct, entropy_sum + return loss, correct, entropy_sum, n_valid_tensor loss = hidden.new_zeros((), dtype=torch.float32) @@ -207,7 +209,7 @@ def _chunked_cross_entropy_loss( if isinstance(num_items_in_batch, torch.Tensor): num_items_in_batch = num_items_in_batch.to(loss.device) loss = loss / num_items_in_batch - return loss, correct, entropy_sum + return loss, correct, entropy_sum, n_valid_tensor def _patch_chunked_ce_lm_head(model: torch.nn.Module, chunk_size: int) -> None: @@ -217,9 +219,10 @@ def _patch_chunked_ce_lm_head(model: torch.nn.Module, chunk_size: int) -> None: When `labels` (or pre-shifted `shift_labels`, for CP/SP) are provided, the patched forward calls `model.get_decoder()` directly, drops positions with `labels == -100` before the `lm_head` matmul, and computes the cross-entropy in chunks of `chunk_size` valid tokens. It returns a [`_ChunkedCELMHeadOutput`] with `loss` set, - `logits` set to `None`, and `num_correct_tokens` / `entropy_sum` populated over non-ignored tokens. For MoE models - with `output_router_logits=True`, the load-balancing auxiliary loss is added with the same coefficient and formula - as the model's own forward, keeping the chunked path numerically equivalent to the reference. + `logits` set to `None`, and `num_correct_tokens` / `entropy_sum` / `num_valid_tokens` populated over non-ignored + tokens. For MoE models with `output_router_logits=True`, the load-balancing auxiliary loss is added with the same + coefficient and formula as the model's own forward, keeping the chunked path numerically equivalent to the + reference. With both `labels` and `shift_labels` set to `None` the original forward runs unchanged, so generation and labels-free evaluation preserve any per-model logits post-processing (e.g. `logit_scale`, @@ -265,7 +268,7 @@ def _chunked_ce_forward( ) hidden_states = outputs.last_hidden_state - loss, num_correct_tokens, entropy_sum = _chunked_cross_entropy_loss( + loss, num_correct_tokens, entropy_sum, num_valid_tokens = _chunked_cross_entropy_loss( hidden_states, self.lm_head.weight, chunk_size, @@ -301,6 +304,7 @@ def _chunked_ce_forward( attentions=outputs.attentions, num_correct_tokens=num_correct_tokens, entropy_sum=entropy_sum, + num_valid_tokens=num_valid_tokens, aux_loss=aux_loss, ) @@ -1570,10 +1574,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Compute entropy if self.args.loss_type == "chunked_nll": - shift_labels = inputs["shift_labels"] if "shift_labels" in inputs else labels[..., 1:] - n_valid = self.accelerator.gather_for_metrics((shift_labels != -100).sum()).sum() + # Use `num_valid_tokens` from the patched forward rather than recomputing from `labels`. Prompt-learning + # PEFT (PromptTuning, P-Tuning) prepends `-100`-padded virtual tokens before delegating into the patched + # forward, so the valid-token count over the padded labels can differ from the un-padded `labels[..., 1:]` + # count by up to one per sequence; using the patched output keeps numerator and denominator aligned. + num_valid = self.accelerator.gather_for_metrics(outputs.num_valid_tokens).sum() entropy_sum = self.accelerator.gather_for_metrics(outputs.entropy_sum).sum() - entropy = (entropy_sum / n_valid).item() if n_valid > 0 else 0.0 + entropy = (entropy_sum / num_valid).item() if num_valid > 0 else 0.0 self._metrics[mode]["entropy"].append(entropy) elif not self.args.use_liger_kernel: # liger doesn't return logits with torch.no_grad(): @@ -1617,10 +1624,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self._metrics[mode]["num_tokens"] = [self._total_train_tokens] if self.args.loss_type == "chunked_nll": - shift_labels = inputs["shift_labels"] if "shift_labels" in inputs else labels[..., 1:] - n_valid = self.accelerator.gather_for_metrics((shift_labels != -100).sum()).sum() + num_valid = self.accelerator.gather_for_metrics(outputs.num_valid_tokens).sum() correct = self.accelerator.gather_for_metrics(outputs.num_correct_tokens).sum() - accuracy = (correct / n_valid).item() if n_valid > 0 else 0.0 + accuracy = (correct / num_valid).item() if num_valid > 0 else 0.0 self._metrics[mode]["mean_token_accuracy"].append(accuracy) elif self.args.use_liger_kernel: if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None: From 30d1d7307fbea6d7cffc3b5f2e374c68e362e848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 29 Apr 2026 00:17:38 +0000 Subject: [PATCH 3/6] style --- tests/test_sft_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 0ba8d177a5c..eb9263d2b35 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -2732,8 +2732,8 @@ def test_forward_matches_reference_with_peft(self, peft_config_factory): def test_num_valid_tokens_with_prompt_learning_peft(self): """For prompt-learning PEFT (PromptTuning, P-Tuning), `PeftModel.forward` prepends `-100`-padded virtual tokens before delegating into the patched inner forward. The patched output's `num_valid_tokens` must reflect - the padded labels — when original `label[0] != -100`, it counts as a valid target paired with the last - virtual token's hidden state, so it must be included in the metric denominator to keep accuracy ≤ 1.""" + the padded labels — when original `label[0] != -100`, it counts as a valid target paired with the last virtual + token's hidden state, so it must be included in the metric denominator to keep accuracy ≤ 1.""" base = AutoModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32 ).to(torch_device) From 69c1b2f826c5d28951b47dd9cf7a638b4a1e1059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 29 Apr 2026 00:22:21 +0000 Subject: [PATCH 4/6] better --- tests/test_sft_trainer.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index eb9263d2b35..06da2b0d301 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -533,7 +533,7 @@ def test_train_chunked_nll_loss_peft(self): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" + torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed") elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @@ -2574,12 +2574,11 @@ class TestPatchChunkedCELMHead: CHUNK_SIZE = 5 # small, to exercise the chunk loop def _setup(self, model_id): - ref_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).to(torch_device) + ref_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map=torch_device) chunked_model = copy.deepcopy(ref_model) _patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE) B, S = 2, 16 - torch.manual_seed(42) input_ids = torch.randint(0, ref_model.config.vocab_size, (B, S), device=torch_device) labels = input_ids.clone() labels[:, :4] = -100 # prompt-like mask @@ -2610,13 +2609,12 @@ def test_forward_matches_reference_with_aux_loss(self, model_id): """MoE models with `output_router_logits=True` add `router_aux_loss_coef * load_balancing_loss` to the main loss. The chunked path must match the reference loss and expose `aux_loss`.""" ref_model = AutoModelForCausalLM.from_pretrained( - model_id, torch_dtype=torch.float32, output_router_logits=True - ).to(torch_device) + model_id, torch_dtype=torch.float32, output_router_logits=True, device_map=torch_device + ) chunked_model = copy.deepcopy(ref_model) _patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE) B, S = 2, 16 - torch.manual_seed(42) input_ids = torch.randint(0, ref_model.config.vocab_size, (B, S), device=torch_device) labels = input_ids.clone() labels[:, :4] = -100 @@ -2696,14 +2694,13 @@ def test_forward_matches_reference_with_peft(self, peft_config_factory): the unpatched PEFT reference for both LoRA-style (adapters live in the module tree) and prompt-learning (`PeftModel.forward` injects virtual tokens, then delegates into the patched inner forward).""" base = AutoModelForCausalLM.from_pretrained( - "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32 - ).to(torch_device) + "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32, device_map=torch_device + ) ref_model = get_peft_model(copy.deepcopy(base), peft_config_factory()) chunked_model = copy.deepcopy(ref_model) _patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE) B, S = 2, 16 - torch.manual_seed(42) input_ids = torch.randint(0, base.config.vocab_size, (B, S), device=torch_device) labels = input_ids.clone() labels[:, :4] = -100 @@ -2735,14 +2732,13 @@ def test_num_valid_tokens_with_prompt_learning_peft(self): the padded labels — when original `label[0] != -100`, it counts as a valid target paired with the last virtual token's hidden state, so it must be included in the metric denominator to keep accuracy ≤ 1.""" base = AutoModelForCausalLM.from_pretrained( - "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32 - ).to(torch_device) + "trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32, device_map=torch_device + ) peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4) chunked_model = get_peft_model(base, peft_config) _patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE) B, S = 2, 16 - torch.manual_seed(42) input_ids = torch.randint(0, base.config.vocab_size, (B, S), device=torch_device) labels = input_ids.clone() # all positions valid, including label[0] From 68f4f94c2d1dc8790c41c9d3084beeef16b552ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 29 Apr 2026 16:46:54 +0000 Subject: [PATCH 5/6] raise --- trl/trainer/sft_trainer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 359fc62a51d..5fb429da603 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -1235,6 +1235,19 @@ def __init__( # `PeftModel.forward`. Prompt-learning variants need `PeftModel.forward` to run first (to inject # virtual tokens), then it delegates into the patched inner forward. target = model.get_base_model() if is_peft_model(model) else model + # The chunked path reads `lm_head.weight` directly, which would silently drop the adapter delta + # (and starve its parameters of gradients) if `lm_head` itself is a PEFT tuner layer. + if is_peft_model(model): + from peft.tuners.tuners_utils import BaseTunerLayer + + if isinstance(target.lm_head, BaseTunerLayer): + raise ValueError( + "`loss_type='chunked_nll'` is not supported when `lm_head` is wrapped by a PEFT " + "adapter (e.g. `target_modules='all-linear'` or explicitly including `'lm_head'`). " + "Either remove `lm_head` from `target_modules`, or switch to `loss_type='nll'`. If " + "this is a real use case for you, please open an issue at " + "https://github.com/huggingface/trl/issues." + ) _patch_chunked_ce_lm_head(target, chunk_size=_CHUNKED_LM_HEAD_CHUNK_SIZE) else: raise ValueError( From 9317f2f5ff7df97a013e12ab9ec2a0c2f359a471 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 5 May 2026 16:06:26 +0000 Subject: [PATCH 6/6] remove redundant num_valid calculation for chunked_nll loss type --- trl/trainer/sft_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 5fb429da603..2538382db43 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -1632,7 +1632,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self._metrics[mode]["num_tokens"] = [self._total_train_tokens] if self.args.loss_type == "chunked_nll": - num_valid = self.accelerator.gather_for_metrics(outputs.num_valid_tokens).sum() correct = self.accelerator.gather_for_metrics(outputs.num_correct_tokens).sum() accuracy = (correct / num_valid).item() if num_valid > 0 else 0.0 self._metrics[mode]["mean_token_accuracy"].append(accuracy)