Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0cf0afb
Enable chunked NLL loss with PEFT in SFT
qgallouedec Apr 28, 2026
b3abe56
fix for prompt tuning
qgallouedec Apr 29, 2026
30d1d73
style
qgallouedec Apr 29, 2026
69c1b2f
better
qgallouedec Apr 29, 2026
77ba54c
Merge branch 'main' into chunked_nll_peft
qgallouedec Apr 29, 2026
68f4f94
raise
qgallouedec Apr 29, 2026
248dbeb
Enable chunked NLL loss with VLM in SFT
qgallouedec Apr 29, 2026
a4cad2e
style
qgallouedec Apr 29, 2026
c81430d
Add VLM support to chunked cross-entropy loss tests
qgallouedec Apr 29, 2026
35903e2
more vlms
qgallouedec Apr 29, 2026
eeb4a37
rm docstring
qgallouedec Apr 29, 2026
7006f11
Merge branch 'main' into chunked_nll_peft
qgallouedec Apr 30, 2026
c0e822f
Merge branch 'chunked_nll_peft' into chunked-nll-vlm
qgallouedec Apr 30, 2026
f256f32
Merge branch 'main' into chunked_nll_peft
qgallouedec May 3, 2026
5764e15
Merge branch 'chunked_nll_peft' into chunked-nll-vlm
qgallouedec May 3, 2026
e9cb282
Merge branch 'main' into chunked_nll_peft
qgallouedec May 4, 2026
5bb1309
Merge branch 'chunked_nll_peft' into chunked-nll-vlm
qgallouedec May 4, 2026
412dacf
fix base model resolution for old transformers versions
qgallouedec May 4, 2026
a080fa3
update auxiliary loss calculation to use text_config parameters
qgallouedec May 4, 2026
93ea587
allow old transformers versions
qgallouedec May 4, 2026
c48bb10
Merge branch 'main' into chunked_nll_peft
qgallouedec May 5, 2026
cd2cf7d
Merge branch 'chunked_nll_peft' into chunked-nll-vlm
qgallouedec May 5, 2026
54da4eb
Merge branch 'main' into chunked-nll-vlm
qgallouedec May 6, 2026
00cb84b
Merge branch 'main' into chunked-nll-vlm
qgallouedec May 6, 2026
de46c96
remove duplicate + consistency
qgallouedec May 6, 2026
7010136
concistency
qgallouedec May 6, 2026
e46a5b0
Merge branch 'main' into chunked-nll-vlm
qgallouedec May 7, 2026
4205b3c
align loss type doc
qgallouedec May 7, 2026
9387319
a bit better testing
qgallouedec May 7, 2026
ec0cad7
Merge branch 'main' into chunked-nll-vlm
qgallouedec May 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 221 additions & 18 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from datasets import load_dataset
from packaging.version import Version
from packaging.version import parse as parse_version
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from transformers.testing_utils import backend_empty_cache, torch_device
from transformers.utils import is_peft_available

Expand Down Expand Up @@ -476,40 +482,113 @@ def test_train_chunked_nll_loss(self):

@require_peft
def test_train_chunked_nll_loss_peft(self):
    """Train with `loss_type="chunked_nll"` and a LoRA config; only the adapter weights may change."""
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

    # Collect the base model parameter names, prefixed so that they can be compared with the names
    # reported by `trainer.model` below
    base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    base_param_names = [f"base_model.model.{name}" for name, _ in base_model.named_parameters()]

    train_dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    trainer = SFTTrainer(
        model=model_id,
        args=SFTConfig(output_dir=self.tmp_dir, loss_type="chunked_nll", report_to="none"),
        train_dataset=train_dataset,
        peft_config=LoraConfig(),
    )

    # Snapshot every parameter before training so it can be compared afterwards
    params_before = {name: tensor.clone() for name, tensor in trainer.model.named_parameters()}

    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check that the peft params have changed and the base model params have not changed
    for name, before in params_before.items():
        after = trainer.model.get_parameter(name)
        if name in base_param_names:  # base model params must be untouched
            torch.testing.assert_close(before, after, msg=f"Parameter {name} has changed.")
            continue
        if "base_layer" in name:  # no expectation is checked for the wrapped base layers
            continue
        assert not torch.equal(before, after), f"Parameter {name} has not changed."

@pytest.mark.parametrize(
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
marks=[
pytest.mark.skipif(
Version(transformers.__version__) < Version("4.57.0"),
reason="Qwen3-VL series were introduced in transformers-4.57.0",
),
],
),
pytest.param(
"trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.2.0"),
reason="Qwen3.5 models were introduced in transformers-5.2.0",
),
),
pytest.param(
"trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.2.0"),
reason="Qwen3.5 models were introduced in transformers-5.2.0",
),
),
],
)
@require_vision
def test_train_chunked_nll_loss_vlm(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")

training_args = SFTConfig(
output_dir=self.tmp_dir,
loss_type="chunked_nll",
max_length=None, # for VLMs, truncating can remove image tokens, leading to errors
report_to="none",
)
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
# For some reason, these params are not updated. This is probably not related to TRL, but to
# the model itself. We should investigate this further, but for now we just skip these params.
# fmt: off
if (
model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we refactor this a bit? Any reason they didn't change?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure; it's been an open question for a long time, but it's never been urgent enough for me to set aside time to investigate. My hunch is that the gradients reaching the vision tower are too weak for the weights to be updated, either because of the structure of the tiny model or because of the initialization values.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for the refactoring, I'd recommend keeping things like this, mostly because it's consistent with TestDPOTrainer.test_train_vlm and TestSFTTrainer.test_train_vlm, plus it explicitly shows which layers are problematic.
Although I agree it's not pretty.

model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or # transformers < 5.6.0
model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or # transformers < 5.6.0
model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.encoder.layers.1" in n or # transformers >= 5.6.0, see #5497
model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.post_layernorm" in n or # transformers >= 5.6.0, see #5497
model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or # transformers < 5.6.0
model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or # transformers < 5.6.0
model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.encoder.layers.1" in n or # transformers >= 5.6.0, see #5497
model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.post_layernorm" in n or # transformers >= 5.6.0, see #5497
model_id == "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration" and "model.visual.deepstack_merger_list" in n
):
# fmt: on
continue
assert not torch.equal(param, new_param), f"Param {n} is not updated"

def test_train_moe_model_with_aux_loss(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
Expand Down Expand Up @@ -2187,6 +2266,43 @@ def test_train_offloading(self, model_name, packing):
]


# Tiny VLM checkpoints used to parametrize the VLM chunked cross-entropy tests. The table below lists
# (model id, minimum transformers version or None, skip reason); entries with a minimum version are wrapped in a
# `pytest.param` carrying the matching `skipif` mark, the others are passed as plain strings.
_CHUNKED_CE_VLM_MODEL_IDS = [
    model_id
    if min_version is None
    else pytest.param(
        model_id,
        marks=pytest.mark.skipif(
            Version(transformers.__version__) < Version(min_version),
            reason=reason,
        ),
    )
    for model_id, min_version, reason in (
        ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration", None, None),
        (
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            "5.5.0",
            "Gemma4 models were introduced in transformers-5.5.0",
        ),
        ("trl-internal-testing/tiny-LlavaForConditionalGeneration", None, None),
        ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration", None, None),
        ("trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", None, None),
        ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", None, None),
        (
            "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
            "4.57.0",
            "Qwen3-VL series were introduced in transformers-4.57.0",
        ),
        (
            "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
            "5.2.0",
            "Qwen3.5 models were introduced in transformers-5.2.0",
        ),
        (
            "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6",
            "5.2.0",
            "Qwen3.5 models were introduced in transformers-5.2.0",
        ),
    )
]
Comment thread
cursor[bot] marked this conversation as resolved.


class TestChunkedCrossEntropyLoss:
B, S, H, V = 2, 8, 4, 16
CHUNK_SIZE = 3 # deliberately small to force multiple chunks and a partial final chunk
Expand Down Expand Up @@ -2365,6 +2481,19 @@ def _setup(self, model_id):
num_items = int((labels[..., 1:] != -100).sum())
return ref_model, chunked_model, input_ids, labels, num_items

def _setup_vlm(self, model_id):
    """Load a reference VLM, build a chunked-CE-patched copy of it, and draw a random text-only batch.

    Returns:
        `(ref_model, chunked_model, input_ids, labels, num_items)`, where `num_items` is the number of
        non-ignored entries in `labels[..., 1:]`.
    """
    ref_model = AutoModelForImageTextToText.from_pretrained(model_id, dtype=torch.float32, device_map=torch_device)

    # Patch a deep copy so that the reference model itself is left unpatched
    chunked_model = copy.deepcopy(ref_model)
    _patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE, is_vlm=True)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(
        0, ref_model.config.text_config.vocab_size, (batch_size, seq_len), device=torch_device
    )
    labels = input_ids.clone()
    labels[:, :4] = -100  # mark the first four positions of every sequence as ignored
    num_items = int((labels[..., 1:] != -100).sum())
    return ref_model, chunked_model, input_ids, labels, num_items

@pytest.mark.parametrize("model_id", _CHUNKED_CE_MODEL_IDS)
def test_forward_matches_reference(self, model_id):
ref_model, chunked_model, input_ids, labels, num_items = self._setup(model_id)
Expand Down Expand Up @@ -2423,13 +2552,87 @@ def test_backward_matches_reference(self, model_id):
)
# Base decoder gradients
for name, ref_param in ref_model.model.named_parameters():
if ref_param.grad is None:
continue
chunked_param = chunked_model.model.get_parameter(name)
torch.testing.assert_close(
chunked_param.grad, ref_param.grad, atol=1e-5, rtol=1e-5, msg=f"gradient mismatch on model.{name}"
chunked_grad = chunked_model.model.get_parameter(name).grad
ref_grad = ref_param.grad
assert (chunked_grad is None) == (ref_grad is None), f"grad presence mismatch on model.{name}"
if ref_grad is not None:
torch.testing.assert_close(
chunked_grad, ref_grad, atol=1e-5, rtol=1e-5, msg=f"gradient mismatch on model.{name}"
)

@pytest.mark.parametrize("model_id", _CHUNKED_CE_VLM_MODEL_IDS)
def test_forward_matches_reference_vlm(self, model_id):
    """The patched VLM forward yields the reference loss, no logits, and non-negative metric outputs."""
    ref_model, chunked_model, input_ids, labels, num_items = self._setup_vlm(model_id)
    forward_kwargs = {"input_ids": input_ids, "labels": labels, "num_items_in_batch": num_items}

    with torch.no_grad():
        expected = ref_model(**forward_kwargs)
        actual = chunked_model(**forward_kwargs)

    torch.testing.assert_close(actual.loss, expected.loss, atol=1e-5, rtol=1e-5)
    assert actual.logits is None

    correct_tokens = actual.num_correct_tokens
    assert correct_tokens is not None
    assert correct_tokens.item() >= 0

    entropy_sum = actual.entropy_sum
    assert entropy_sum is not None
    assert entropy_sum.item() >= 0.0

@pytest.mark.parametrize(
    "model_id",
    [
        pytest.param(
            "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.2.0"),
                reason="Qwen3.5 models were introduced in transformers-5.2.0",
            ),
        ),
    ],
)
def test_forward_matches_reference_vlm_with_aux_loss(self, model_id):
    """The patched VLM MoE forward matches the reference on both the loss and the auxiliary loss."""
    # Same models and batch as the other VLM tests; `_setup_vlm` performs the load / deepcopy / patch /
    # random-batch steps in the order this test previously spelled out inline.
    ref_model, chunked_model, input_ids, labels, num_items = self._setup_vlm(model_id)

    # VLM MoE wrappers only read `output_router_logits` from forward kwargs (their `text_config` explicitly
    # removes the attribute), so we have to pass it at call time on both paths.
    with torch.no_grad():
        ref_out = ref_model(
            input_ids=input_ids, labels=labels, num_items_in_batch=num_items, output_router_logits=True
        )
        out = chunked_model(
            input_ids=input_ids, labels=labels, num_items_in_batch=num_items, output_router_logits=True
        )

    # `assert_close` accepts a pair of `None` values, so without this guard the aux-loss comparison below
    # would pass vacuously if no auxiliary loss were produced on either path.
    assert ref_out.aux_loss is not None, "reference model did not return an aux loss"

    torch.testing.assert_close(out.loss, ref_out.loss, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(out.aux_loss, ref_out.aux_loss, atol=1e-5, rtol=1e-5)

@pytest.mark.parametrize("model_id", _CHUNKED_CE_VLM_MODEL_IDS)
def test_backward_matches_reference_vlm(self, model_id):
    """Gradients from the patched VLM forward match the reference on `lm_head` and on every `model` parameter."""
    ref_model, chunked_model, input_ids, labels, num_items = self._setup_vlm(model_id)
    forward_kwargs = {"input_ids": input_ids, "labels": labels, "num_items_in_batch": num_items}
    tolerances = {"atol": 1e-5, "rtol": 1e-5}

    # Reference first, then the chunked copy: one forward + backward pass each
    for model in (ref_model, chunked_model):
        model(**forward_kwargs).loss.backward()

    # lm_head gradient
    torch.testing.assert_close(chunked_model.lm_head.weight.grad, ref_model.lm_head.weight.grad, **tolerances)

    # Multimodal-wrapper gradients (covers both vision tower and inner text decoder).
    for name, ref_param in ref_model.model.named_parameters():
        ref_grad = ref_param.grad
        chunked_grad = chunked_model.model.get_parameter(name).grad
        # A parameter must receive a gradient on both paths or on neither
        assert (chunked_grad is None) == (ref_grad is None), f"grad presence mismatch on model.{name}"
        if ref_grad is None:
            continue
        torch.testing.assert_close(chunked_grad, ref_grad, msg=f"gradient mismatch on model.{name}", **tolerances)

def test_forward_without_labels_uses_original_path(self):
"""With labels=None the patched forward returns real logits (for generation / eval)."""
_, chunked_model, input_ids, _, _ = self._setup("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
Expand Down
12 changes: 7 additions & 5 deletions trl/trainer/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ class SFTConfig(_BaseConfig):
[this paper](https://huggingface.co/papers/2508.05629).
- `"chunked_nll"`: same math as `"nll"`, but the `lm_head` projection is computed on non-ignored tokens
only (positions with `labels == -100` are dropped before the matmul) and the cross-entropy is processed
in chunks of tokens to reduce peak activation memory. Not compatible with `use_liger_kernel`, PEFT, or
VLM models. Under FSDP2, set `fsdp_reshard_after_forward false` in the accelerate config — the chunked
path otherwise re-gathers `lm_head.weight` per chunk during backward, adding noticeable wall-time.
in chunks of tokens to reduce peak activation memory. Not compatible with `use_liger_kernel`. Under
FSDP2, set `fsdp_reshard_after_forward false` in the accelerate config — the chunked path otherwise
re-gathers `lm_head.weight` per chunk during backward, adding noticeable wall-time.
activation_offloading (`bool`, *optional*, defaults to `False`):
Whether to offload the activations to the CPU.

Expand Down Expand Up @@ -266,8 +266,10 @@ class SFTConfig(_BaseConfig):
metadata={
"help": "Type of loss to use. Possible values are `'nll'` (negative log-likelihood, default), `'dft'` "
"(Dynamic Fine-Tuning, https://huggingface.co/papers/2508.05629), and `'chunked_nll'` (same math as "
"`'nll'` but skips the `'lm_head'` matmul on ignored tokens and chunks the CE to reduce peak memory; not "
"compatible with Liger, PEFT, or VLM)."
"`'nll'`, but the `lm_head` projection is computed on non-ignored tokens only and the cross-entropy is "
"processed in chunks of tokens to reduce peak activation memory. Not compatible with `use_liger_kernel`. "
"Under FSDP2, set `fsdp_reshard_after_forward false` in the accelerate config — the chunked path "
"otherwise re-gathers `lm_head.weight` per chunk during backward, adding noticeable wall-time."
},
)
activation_offloading: bool = field(
Expand Down
Loading
Loading