Skip to content
155 changes: 139 additions & 16 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,43 @@ def test_train_chunked_nll_loss(self):
new_param = trainer.model.get_parameter(n)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

@require_peft
def test_train_chunked_nll_loss_peft(self):
# Get the base model parameter names
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]

# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

# Initialize the trainer
training_args = SFTConfig(output_dir=self.tmp_dir, loss_type="chunked_nll", report_to="none")

trainer = SFTTrainer(
model=model_id,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(),
)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None

# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

def test_train_moe_model_with_aux_loss(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
Expand Down Expand Up @@ -2410,22 +2447,24 @@ def _reference(hidden, weight, labels, num_items_in_batch=None):
def test_forward_matches_cross_entropy(self):
"""With no ignored tokens, chunked loss equals standard mean cross-entropy."""
hidden, weight, labels = self._inputs()
n_valid = (labels[..., 1:] != -100).sum()
loss_c, correct_c, ent_sum_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels)
expected_n_valid = (labels[..., 1:] != -100).sum()
loss_c, correct_c, ent_sum_c, n_valid_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels)
loss_r, acc_r, ent_r = self._reference(hidden, weight, labels)
torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(correct_c / n_valid, acc_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ent_sum_c / n_valid, ent_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(correct_c / n_valid_c, acc_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ent_sum_c / n_valid_c, ent_r, atol=1e-5, rtol=1e-5)
assert n_valid_c.item() == expected_n_valid.item()

def test_forward_ignore_index(self):
"""Ignored labels are excluded from loss, accuracy and entropy (matches F.cross_entropy)."""
hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3))
n_valid = (labels[..., 1:] != -100).sum()
loss_c, correct_c, ent_sum_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels)
expected_n_valid = (labels[..., 1:] != -100).sum()
loss_c, correct_c, ent_sum_c, n_valid_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels)
loss_r, acc_r, ent_r = self._reference(hidden, weight, labels)
torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(correct_c / n_valid, acc_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ent_sum_c / n_valid, ent_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(correct_c / n_valid_c, acc_r, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ent_sum_c / n_valid_c, ent_r, atol=1e-5, rtol=1e-5)
assert n_valid_c.item() == expected_n_valid.item()

def test_num_items_in_batch_reduction(self):
"""When num_items_in_batch is provided, loss is sum / num_items_in_batch."""
Expand Down Expand Up @@ -2469,12 +2508,13 @@ def test_all_ignored_returns_zero(self):
hidden, weight, labels = self._inputs(requires_grad=True)
bias = torch.zeros(self.V, dtype=torch.float32, requires_grad=True)
labels[:] = -100
loss, correct, ent_sum = _chunked_cross_entropy_loss(
loss, correct, ent_sum, n_valid = _chunked_cross_entropy_loss(
hidden, weight, self.CHUNK_SIZE, labels, lm_head_bias=bias
)
assert loss.item() == 0.0
assert correct.item() == 0.0
assert ent_sum.item() == 0.0
assert n_valid.item() == 0
assert not torch.isnan(loss)
# Backward must succeed even when n_valid == 0 (can happen with completion-only loss
# + truncation where a whole micro-batch is masked).
Expand All @@ -2486,15 +2526,16 @@ def test_all_ignored_returns_zero(self):
def test_shift_labels_matches_labels(self):
"""`shift_labels` path (CP/SP) must match the default `labels` path after external shifting."""
hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3))
loss_l, correct_l, ent_l = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels)
loss_l, correct_l, ent_l, n_valid_l = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels)
# Mimic what transformers does under CP/SP: pad labels with -100, then shift.
shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous()
loss_s, correct_s, ent_s = _chunked_cross_entropy_loss(
loss_s, correct_s, ent_s, n_valid_s = _chunked_cross_entropy_loss(
hidden, weight, self.CHUNK_SIZE, shift_labels=shift_labels
)
torch.testing.assert_close(loss_s, loss_l, atol=1e-6, rtol=1e-6)
torch.testing.assert_close(correct_s, correct_l, atol=1e-6, rtol=1e-6)
torch.testing.assert_close(ent_s, ent_l, atol=1e-6, rtol=1e-6)
assert n_valid_s.item() == n_valid_l.item()

def test_requires_labels_or_shift_labels(self):
"""Must provide at least one of `labels` or `shift_labels`."""
Expand Down Expand Up @@ -2536,12 +2577,11 @@ class TestPatchChunkedCELMHead:
CHUNK_SIZE = 5 # small, to exercise the chunk loop

def _setup(self, model_id):
ref_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).to(torch_device)
ref_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map=torch_device)
chunked_model = copy.deepcopy(ref_model)
_patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE)

B, S = 2, 16
torch.manual_seed(42)
input_ids = torch.randint(0, ref_model.config.vocab_size, (B, S), device=torch_device)
labels = input_ids.clone()
labels[:, :4] = -100 # prompt-like mask
Expand Down Expand Up @@ -2572,13 +2612,12 @@ def test_forward_matches_reference_with_aux_loss(self, model_id):
"""MoE models with `output_router_logits=True` add `router_aux_loss_coef * load_balancing_loss`
to the main loss. The chunked path must match the reference loss and expose `aux_loss`."""
ref_model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float32, output_router_logits=True
).to(torch_device)
model_id, torch_dtype=torch.float32, output_router_logits=True, device_map=torch_device
)
chunked_model = copy.deepcopy(ref_model)
_patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE)

B, S = 2, 16
torch.manual_seed(42)
input_ids = torch.randint(0, ref_model.config.vocab_size, (B, S), device=torch_device)
labels = input_ids.clone()
labels[:, :4] = -100
Expand Down Expand Up @@ -2631,3 +2670,87 @@ def test_forward_without_labels_matches_reference(self):
ref_out = ref_model(input_ids=input_ids)
out = chunked_model(input_ids=input_ids)
torch.testing.assert_close(out.logits, ref_out.logits, atol=1e-5, rtol=1e-5)

@require_peft
@pytest.mark.filterwarnings("ignore:Model has `tie_word_embeddings=True`")
@pytest.mark.parametrize(
"peft_config_factory",
[
pytest.param(lambda: LoraConfig(r=4, target_modules=["q_proj", "v_proj"]), id="lora"),
pytest.param(
lambda: LoraConfig(r=4, target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"]),
id="lora+modules_to_save",
),
pytest.param(
lambda: PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4), id="prompt_tuning"
),
pytest.param(
lambda: PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4), id="prompt_encoder"
),
pytest.param(
lambda: PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4), id="prefix_tuning"
),
],
)
def test_forward_matches_reference_with_peft(self, peft_config_factory):
"""Patching the inner causal LM (`peft_model.get_base_model()`) must produce a forward whose loss matches
the unpatched PEFT reference for both LoRA-style (adapters live in the module tree) and prompt-learning
(`PeftModel.forward` injects virtual tokens, then delegates into the patched inner forward)."""
base = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32, device_map=torch_device
)
ref_model = get_peft_model(copy.deepcopy(base), peft_config_factory())
chunked_model = copy.deepcopy(ref_model)
_patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE)

B, S = 2, 16
input_ids = torch.randint(0, base.config.vocab_size, (B, S), device=torch_device)
labels = input_ids.clone()
labels[:, :4] = -100
num_items = int((labels[..., 1:] != -100).sum())

ref_out = ref_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items)
out = chunked_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items)
torch.testing.assert_close(out.loss, ref_out.loss, atol=1e-5, rtol=1e-5)

ref_out.loss.backward()
out.loss.backward()
chunked_params = dict(chunked_model.named_parameters())
for name, ref_param in ref_model.named_parameters():
if not ref_param.requires_grad or ref_param.grad is None:
continue
torch.testing.assert_close(
chunked_params[name].grad,
ref_param.grad,
atol=1e-5,
rtol=1e-5,
msg=f"gradient mismatch on {name}",
)

@require_peft
@pytest.mark.filterwarnings("ignore:Model has `tie_word_embeddings=True`")
def test_num_valid_tokens_with_prompt_learning_peft(self):
"""For prompt-learning PEFT (PromptTuning, P-Tuning), `PeftModel.forward` prepends `-100`-padded virtual
tokens before delegating into the patched inner forward. The patched output's `num_valid_tokens` must reflect
the padded labels — when original `label[0] != -100`, it counts as a valid target paired with the last virtual
token's hidden state, so it must be included in the metric denominator to keep accuracy ≤ 1."""
base = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Qwen3ForCausalLM", dtype=torch.float32, device_map=torch_device
)
peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4)
chunked_model = get_peft_model(base, peft_config)
_patch_chunked_ce_lm_head(chunked_model.get_base_model(), chunk_size=self.CHUNK_SIZE)

B, S = 2, 16
input_ids = torch.randint(0, base.config.vocab_size, (B, S), device=torch_device)
labels = input_ids.clone() # all positions valid, including label[0]

out = chunked_model(input_ids=input_ids, labels=labels)

# `labels[..., 1:]` (un-padded, what compute_loss used to compute) excludes original `label[0]`,
# but the patched forward sees padded labels and counts `label[0]` as a valid target.
unpadded = int((labels[..., 1:] != -100).sum())
# One extra valid target per sequence (original `label[0]`).
assert out.num_valid_tokens.item() == unpadded + B
# Accuracy denominator from the patched output keeps numerator/denominator aligned, so accuracy ≤ 1.
assert out.num_correct_tokens.item() <= out.num_valid_tokens.item()
Loading
Loading