Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
ea2973d
initial commit
qgallouedec Aug 7, 2025
0f9aa4c
Merge branch 'main' into native-vlm-support
qgallouedec Aug 7, 2025
f72dd39
proper image token ids
qgallouedec Aug 7, 2025
3b78c35
fix tiny model
qgallouedec Aug 7, 2025
95d7767
consistency
qgallouedec Aug 7, 2025
3c5aab9
fix tiny model
qgallouedec Aug 7, 2025
d4a5f67
fix test
qgallouedec Aug 7, 2025
590f997
this should work
qgallouedec Aug 7, 2025
015fd2b
fix gemma
qgallouedec Aug 7, 2025
c124d73
Merge branch 'main' into native-vlm-support
sergiopaniego Aug 7, 2025
1b76e66
Merge branch 'main' into native-vlm-support
kashif Aug 7, 2025
533ba8c
dtype check and scripts update
sergiopaniego Aug 7, 2025
1b967c4
add vision requirement to test_train_vlm in SFTTrainerTester2
qgallouedec Aug 8, 2025
4f677fa
remove force option from push_to_hub in generate_tiny_models.py
qgallouedec Aug 8, 2025
d4a122b
add test case for tiny-Qwen2VLForConditionalGeneration in SFTTrainerT…
qgallouedec Aug 8, 2025
7207802
generate idefics3
qgallouedec Aug 9, 2025
5c53f48
update test
qgallouedec Aug 9, 2025
9cabca2
a lot better
qgallouedec Aug 9, 2025
aea083d
Merge branch 'main' into native-vlm-support
qgallouedec Aug 9, 2025
fda5c1e
Update trl/trainer/sft_trainer.py
qgallouedec Aug 9, 2025
fb203ff
Update scripts/generate_tiny_models.py
qgallouedec Aug 9, 2025
8d5ff49
clean test
qgallouedec Aug 9, 2025
df05be1
Add llava_instruct_mix.py dataset processing script
qgallouedec Aug 9, 2025
b9a01a2
doc
qgallouedec Aug 9, 2025
230d691
update doc
qgallouedec Aug 9, 2025
0ed5e80
fix llava mix dataset
qgallouedec Aug 9, 2025
329667c
fix doc
qgallouedec Aug 9, 2025
3b99cee
remove training vlm sft
qgallouedec Aug 9, 2025
8aa2381
imageS
qgallouedec Aug 9, 2025
4be3118
Update docs/source/sft_trainer.md
qgallouedec Aug 9, 2025
e06b2fb
Small docs nits
sergiopaniego Aug 11, 2025
fbcc78d
Updated sft_vlm.py example
sergiopaniego Aug 11, 2025
055a86a
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
15a2605
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
30b7a2d
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
dbc4b65
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
fa05ed2
style
qgallouedec Aug 12, 2025
3ae6682
ignore failing test
qgallouedec Aug 12, 2025
f9c1fec
Clarify behavior of `skip_prepare_dataset` for VLM models in SFTConfi…
qgallouedec Aug 12, 2025
089b732
Add documentation for DataCollator classes in SFTTrainer
qgallouedec Aug 12, 2025
e9c3b82
new tiny style
qgallouedec Aug 13, 2025
7c070b9
mnior + clean
qgallouedec Aug 13, 2025
e8ef1d3
fix tiny qwen2
qgallouedec Aug 13, 2025
eb22b9c
fix doc and comments
qgallouedec Aug 13, 2025
13ed3ee
final example cleaning
qgallouedec Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/scripts/sft_video_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ class CustomScriptArguments(ScriptArguments):
# Configure training args
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

# Load dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train")
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

################
# Model, Tokenizer & Processor
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_vlm_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def main():
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

################
# Model, Tokenizer & Processor
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_vlm_smol_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

################
# Model, Tokenizer & Processor
Expand Down
26 changes: 18 additions & 8 deletions scripts/generate_tiny_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
GptOssForCausalLM,
Idefics2Config,
Idefics2ForConditionalGeneration,
Idefics3Config,
Idefics3ForConditionalGeneration,
LlamaConfig,
LlamaForCausalLM,
LlamaForSequenceClassification,
Expand Down Expand Up @@ -98,7 +100,7 @@
api = HfApi()


def push_to_hub(model, tokenizer, prefix=None, suffix=None):
def push_to_hub(model, tokenizer, prefix=None, suffix=None, force=False):
model_class_name = model.__class__.__name__
content = MODEL_CARD.format(model_class_name=model_class_name)
model_card = ModelCard(content)
Expand All @@ -108,7 +110,7 @@ def push_to_hub(model, tokenizer, prefix=None, suffix=None):
if suffix is not None:
repo_id += f"-{suffix}"

if api.repo_exists(repo_id):
if api.repo_exists(repo_id) and not force:
print(f"Model {repo_id} already exists, skipping")
else:
model.push_to_hub(repo_id)
Expand Down Expand Up @@ -283,6 +285,7 @@ def init_weights_tiny_model(model):
("google/gemma-3-4b-it", Gemma3Config, Gemma3ForConditionalGeneration),
("google/paligemma-3b-pt-224", PaliGemmaConfig, PaliGemmaForConditionalGeneration),
("HuggingFaceM4/idefics2-8b", Idefics2Config, Idefics2ForConditionalGeneration),
("HuggingFaceM4/Idefics3-8B-Llama3", Idefics3Config, Idefics3ForConditionalGeneration),
("HuggingFaceTB/SmolVLM2-2.2B-Instruct", SmolVLMConfig, SmolVLMForConditionalGeneration),
("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlavaForConditionalGeneration),
("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, LlavaNextForConditionalGeneration),
Expand All @@ -293,31 +296,38 @@ def init_weights_tiny_model(model):
kwargs = {}
text_kwargs = {}
vision_kwargs = {}
if config_class == PaliGemmaConfig:
if config_class in [PaliGemmaConfig]:
kwargs["projection_dim"] = 8
if config_class in [LlavaConfig, LlavaNextConfig, PaliGemmaConfig]:
vision_kwargs["projection_dim"] = 8
if config_class in [LlavaConfig, LlavaNextConfig]:
if config_class in [LlavaConfig, LlavaNextConfig, Gemma3Config]:
vision_kwargs["image_size"] = 336
vision_kwargs["patch_size"] = 14
vision_kwargs["patch_size"] = 20
processor.image_processor.size = {"height": 336, "width": 336}
if config_class in [Qwen2VLConfig, Qwen2_5_VLConfig]:
kwargs["vision_start_token_id"] = 151652
text_kwargs["rope_scaling"] = {"type": "mrope", "mrope_section": [1]}
kwargs["vision_end_token_id"] = 151653
kwargs["vision_token_id"] = 151654
kwargs["image_token_id"] = 151655
kwargs["vocab_size"] = len(processor.tokenizer.vocab)
text_kwargs["rope_scaling"] = {"type": "mrope", "mrope_section": [2]}
vision_kwargs["depth"] = 4
vision_kwargs["embed_dim"] = 64
if config_class in [Qwen2_5_VLConfig]:
vision_kwargs["out_hidden_size"] = 16

config = config_class(
text_config=dict(
vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder),
hidden_size=8,
hidden_size=16,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
intermediate_size=32,
**text_kwargs,
),
vision_config=dict(
hidden_size=16,
hidden_size=2048,
Comment thread
qgallouedec marked this conversation as resolved.
Outdated
num_attention_heads=4,
num_hidden_layers=2,
intermediate_size=32,
Expand Down
183 changes: 93 additions & 90 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,11 @@
import unittest

import numpy as np
import pytest
import torch
from datasets import Dataset, Image, Sequence, load_dataset
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
LlavaForConditionalGeneration,
is_vision_available,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, is_vision_available
from transformers.testing_utils import require_flash_attn, require_peft, require_vision
from transformers.utils import is_peft_available

Expand Down Expand Up @@ -507,89 +502,6 @@ def test_no_packing(self):
self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"]))
self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))

@require_vision
def test_skip_prepare_dataset(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
remove_unused_columns=False,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)

trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_vsft_instruction_dataset,
)
self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features)

def test_skip_prepare_dataset_with_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
remove_unused_columns=False,
packing=False,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)

trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_dataset,
)
self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features)

@require_vision
def test_llava(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
remove_unused_columns=False,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)
tiny_llava = LlavaForConditionalGeneration.from_pretrained(
"trl-internal-testing/tiny-LlavaForConditionalGeneration"
)
processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration")

processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious
user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's
questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for
item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image'
%}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if
add_generation_prompt %}ASSISTANT: {% endif %}"""

def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]

# Tokenize the texts and process the images
batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels

return batch

trainer = SFTTrainer(
model=tiny_llava,
args=training_args,
data_collator=collate_fn,
train_dataset=self.dummy_vsft_instruction_dataset,
)

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])


# This new tester aims to replace the first one at some point
class SFTTrainerTester2(unittest.TestCase):
Expand Down Expand Up @@ -1472,3 +1384,94 @@ def test_train_with_torch_dtype(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@parameterized.expand(
[
("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",),
("trl-internal-testing/tiny-Idefics3ForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
("trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",),
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
("trl-internal-testing/tiny-SmolVLMForConditionalGeneration",),
]
)
@require_vision
def test_train_vlm(self, model_id):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir,
max_length=None, # For VLMs, truncating can remove image tokens, leading to errors
report_to="none",
)
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# For some reason, these params are not updated. This is probably not related to TRL, but to
# the model itself. We should investigate this further, but for now we just skip these params.
# fmt: off
if (
model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or
model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or
model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or
model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or
model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n
):
# fmt: on
continue
self.assertFalse(
torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
)

# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
@pytest.mark.slow
@require_vision
def test_train_vlm_gemma_3n(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir,
max_length=None,
per_device_train_batch_size=1,
gradient_checkpointing=True,
model_init_kwargs={"torch_dtype": "bfloat16"},
report_to="none",
)
trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "model.vision_tower.timm_model.conv_stem.bn.weight" in n:
# This parameter is not updated, not sure why at this point.
continue
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
6 changes: 3 additions & 3 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def reward_func(completions, **kwargs):
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
processing_class ([`~transformers.PreTrainedTokenizerBase`] [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
Expand Down Expand Up @@ -534,9 +534,9 @@ def __init__(
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
raise ValueError(
warnings.warn(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
"The `model_init_kwargs` will be ignored."
)

# Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it
Expand Down
Loading
Loading