Skip to content
94 changes: 94 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,43 @@ def test_train_chunked_nll_loss(self):
new_param = trainer.model.get_parameter(n)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

@require_peft
def test_train_chunked_nll_loss_peft(self):
    """Train with `loss_type="chunked_nll"` and a default `LoraConfig`, then check which parameters moved.

    Parameters whose names match the un-wrapped base model must be unchanged after training, while every other
    parameter (except the wrapped `base_layer` copies) must have been updated.
    """
    # Get the base model parameter names (a set, since it is only used for membership tests below)
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    base_param_names = {f"base_model.model.{n}" for n, _ in model.named_parameters()}

    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Initialize the trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, loss_type="chunked_nll", report_to="none")

    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=LoraConfig(),
    )

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the peft params have changed and the base model params have not changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:  # We expect the base model parameters to be the same
            # The message must be passed via `msg=`: written as a trailing `, f"..."` it only builds a
            # throwaway tuple and never reaches `assert_close`.
            torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
        elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

def test_train_moe_model_with_aux_loss(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
Expand Down Expand Up @@ -2628,3 +2665,60 @@ def test_forward_without_labels_matches_reference(self):
ref_out = ref_model(input_ids=input_ids)
out = chunked_model(input_ids=input_ids)
torch.testing.assert_close(out.logits, ref_out.logits, atol=1e-5, rtol=1e-5)

@require_peft
@pytest.mark.filterwarnings("ignore:Model has `tie_word_embeddings=True`")
@pytest.mark.parametrize(
    "peft_config_factory",
    [
        pytest.param(
            lambda: LoraConfig(r=4, target_modules=["q_proj", "v_proj"]),
            id="lora",
        ),
        pytest.param(
            lambda: LoraConfig(r=4, target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"]),
            id="lora+modules_to_save",
        ),
        pytest.param(
            lambda: PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),
            id="prompt_tuning",
        ),
        pytest.param(
            lambda: PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),
            id="prompt_encoder",
        ),
        pytest.param(
            lambda: PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),
            id="prefix_tuning",
        ),
    ],
)
def test_forward_matches_reference_with_peft(self, peft_config_factory):
    """Chunked-CE patching of a PEFT model's inner causal LM must agree with the unpatched PEFT model.

    A PEFT-wrapped reference model is deep-copied, the copy's `get_base_model()` is patched, and both models are
    run on the same partially masked batch. The loss and the gradient of every reference parameter that received
    one must match within `1e-5`. The parametrization covers LoRA-style configs (adapters live in the module tree)
    and prompt-learning configs (`PeftModel.forward` injects virtual tokens, then delegates into the patched inner
    forward).
    """
    tolerance = {"atol": 1e-5, "rtol": 1e-5}

    base = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32)
    base = base.to(torch_device)

    # Reference and patched models start from identical weights: the patched one is a deep copy of the reference.
    ref_model = get_peft_model(copy.deepcopy(base), peft_config_factory())
    chunked_model = copy.deepcopy(ref_model)
    _patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE)

    batch_size, seq_len = 2, 16
    torch.manual_seed(42)
    input_ids = torch.randint(0, base.config.vocab_size, (batch_size, seq_len), device=torch_device)
    labels = input_ids.clone()
    labels[:, :4] = -100  # ignore the first four positions of every sequence
    num_items = int((labels[..., 1:] != -100).sum())
    forward_kwargs = {"input_ids": input_ids, "labels": labels, "num_items_in_batch": num_items}

    ref_out = ref_model(**forward_kwargs)
    out = chunked_model(**forward_kwargs)
    torch.testing.assert_close(out.loss, ref_out.loss, **tolerance)

    ref_out.loss.backward()
    out.loss.backward()
    chunked_params = dict(chunked_model.named_parameters())
    for name, ref_param in ref_model.named_parameters():
        # Only parameters that are trainable and actually received a gradient in the reference are compared.
        if ref_param.requires_grad and ref_param.grad is not None:
            torch.testing.assert_close(
                chunked_params[name].grad,
                ref_param.grad,
                msg=f"gradient mismatch on {name}",
                **tolerance,
            )
38 changes: 21 additions & 17 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,23 +212,24 @@ def _chunked_cross_entropy_loss(

def _patch_chunked_ce_lm_head(model: torch.nn.Module, chunk_size: int) -> None:
"""
Patch a causal LM so `forward` computes the language modeling loss via [`_chunked_cross_entropy_loss`].

When `labels` (or pre-shifted `shift_labels`, for the context / sequence parallelism path) are provided, the patched
forward calls `model.get_decoder()` directly, drops positions with `labels == -100` before the `lm_head` matmul, and
computes the cross-entropy in chunks of `chunk_size` valid tokens. It returns a [`_ChunkedCELMHeadOutput`] with
`loss` set, `logits` set to `None`, and `num_correct_tokens` / `entropy_sum` populated over non-ignored tokens. For
MoE models with `output_router_logits=True`, the load-balancing auxiliary loss is added with the same coefficient
(`router_aux_loss_coef`) and formula (`load_balancing_loss_func`) as the model's own forward, keeping the chunked
path numerically equivalent to the reference.

With both `labels` and `shift_labels` set to `None` the original forward runs unchanged, so generation and
labels-free evaluation preserve any per-model logits post-processing (e.g. `logit_scale`,
`final_logit_softcapping`).

For PEFT, pass the inner causal LM (`peft_model.get_base_model()`) rather than the `PeftModel` wrapper, so that
prompt-learning variants (PromptTuning, PrefixTuning, PTuning) keep their virtual-token injection in
`PeftModel.forward` before delegating into the patched forward.

Not supported yet: VLM / multimodal models whose forward injects visual tokens outside the base decoder.
"""
final_logit_softcapping = getattr(model.config, "final_logit_softcapping", None)
logit_scale = getattr(model.config, "logit_scale", 1.0)
Expand Down Expand Up @@ -1230,9 +1231,12 @@ def __init__(
# wraps the patched forward.
if self._is_vlm:
raise NotImplementedError("`loss_type='chunked_nll'` is not supported for VLM models yet.")
if peft_config is not None or is_peft_model(model):
raise NotImplementedError("`loss_type='chunked_nll'` is not supported with PEFT yet.")
_patch_chunked_ce_lm_head(model, chunk_size=_CHUNKED_LM_HEAD_CHUNK_SIZE)
# For PEFT, patch the inner causal LM rather than the `PeftModel` wrapper. LoRA / IA³ /
# `modules_to_save` adapters live in the module tree, so they're hit even when we bypass
# `PeftModel.forward`. Prompt-learning variants need `PeftModel.forward` to run first (to inject
# virtual tokens), then it delegates into the patched inner forward.
target = model.get_base_model() if is_peft_model(model) else model
Comment thread
qgallouedec marked this conversation as resolved.
_patch_chunked_ce_lm_head(target, chunk_size=_CHUNKED_LM_HEAD_CHUNK_SIZE)
else:
raise ValueError(
f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll', 'dft', and "
Expand Down
Loading