Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
ea2973d
initial commit
qgallouedec Aug 7, 2025
0f9aa4c
Merge branch 'main' into native-vlm-support
qgallouedec Aug 7, 2025
f72dd39
proper image token ids
qgallouedec Aug 7, 2025
3b78c35
fix tiny model
qgallouedec Aug 7, 2025
95d7767
consistency
qgallouedec Aug 7, 2025
3c5aab9
fix tiny model
qgallouedec Aug 7, 2025
d4a5f67
fix test
qgallouedec Aug 7, 2025
590f997
this should work
qgallouedec Aug 7, 2025
015fd2b
fix gemma
qgallouedec Aug 7, 2025
c124d73
Merge branch 'main' into native-vlm-support
sergiopaniego Aug 7, 2025
1b76e66
Merge branch 'main' into native-vlm-support
kashif Aug 7, 2025
533ba8c
dtype check and scripts update
sergiopaniego Aug 7, 2025
1b967c4
add vision requirement to test_train_vlm in SFTTrainerTester2
qgallouedec Aug 8, 2025
4f677fa
remove force option from push_to_hub in generate_tiny_models.py
qgallouedec Aug 8, 2025
d4a122b
add test case for tiny-Qwen2VLForConditionalGeneration in SFTTrainerT…
qgallouedec Aug 8, 2025
7207802
generate idefics3
qgallouedec Aug 9, 2025
5c53f48
update test
qgallouedec Aug 9, 2025
9cabca2
a lot better
qgallouedec Aug 9, 2025
aea083d
Merge branch 'main' into native-vlm-support
qgallouedec Aug 9, 2025
fda5c1e
Update trl/trainer/sft_trainer.py
qgallouedec Aug 9, 2025
fb203ff
Update scripts/generate_tiny_models.py
qgallouedec Aug 9, 2025
8d5ff49
clean test
qgallouedec Aug 9, 2025
df05be1
Add llava_instruct_mix.py dataset processing script
qgallouedec Aug 9, 2025
b9a01a2
doc
qgallouedec Aug 9, 2025
230d691
update doc
qgallouedec Aug 9, 2025
0ed5e80
fix llava mix dataset
qgallouedec Aug 9, 2025
329667c
fix doc
qgallouedec Aug 9, 2025
3b99cee
remove training vlm sft
qgallouedec Aug 9, 2025
8aa2381
imageS
qgallouedec Aug 9, 2025
4be3118
Update docs/source/sft_trainer.md
qgallouedec Aug 9, 2025
e06b2fb
Small docs nits
sergiopaniego Aug 11, 2025
fbcc78d
Updated sft_vlm.py example
sergiopaniego Aug 11, 2025
055a86a
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
15a2605
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
30b7a2d
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
dbc4b65
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
fa05ed2
style
qgallouedec Aug 12, 2025
3ae6682
ignore failing test
qgallouedec Aug 12, 2025
f9c1fec
Clarify behavior of `skip_prepare_dataset` for VLM models in SFTConfi…
qgallouedec Aug 12, 2025
089b732
Add documentation for DataCollator classes in SFTTrainer
qgallouedec Aug 12, 2025
e9c3b82
new tiny style
qgallouedec Aug 13, 2025
7c070b9
mnior + clean
qgallouedec Aug 13, 2025
e8ef1d3
fix tiny qwen2
qgallouedec Aug 13, 2025
eb22b9c
fix doc and comments
qgallouedec Aug 13, 2025
13ed3ee
final example cleaning
qgallouedec Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/scripts/sft_video_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ class CustomScriptArguments(ScriptArguments):
# Configure training args
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

# Load dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train")
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--output_dir sft-llava-1.5-7b-hf \
--bf16 \
--bf16 True \
--torch_dtype bfloat16 \
--gradient_checkpointing

Expand Down Expand Up @@ -63,7 +63,6 @@
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

################
# Model, Tokenizer & Processor
Expand Down
5 changes: 2 additions & 3 deletions examples/scripts/sft_vlm_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
--bf16 \
--bf16 True \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
Expand All @@ -47,7 +47,7 @@
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
--bf16 \
--bf16 True \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear
Expand Down Expand Up @@ -142,7 +142,6 @@ def main():
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

################
# Model, Tokenizer & Processor
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/sft_vlm_smol_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir sft-smol-vlm-hf \
--bf16 \
--bf16 True \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
Expand Down Expand Up @@ -70,7 +70,6 @@
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

################
# Model, Tokenizer & Processor
Expand Down
21 changes: 14 additions & 7 deletions scripts/generate_tiny_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
api = HfApi()


def push_to_hub(model, tokenizer, prefix=None, suffix=None):
def push_to_hub(model, tokenizer, prefix=None, suffix=None, force=False):
model_class_name = model.__class__.__name__
content = MODEL_CARD.format(model_class_name=model_class_name)
model_card = ModelCard(content)
Expand All @@ -108,7 +108,7 @@ def push_to_hub(model, tokenizer, prefix=None, suffix=None):
if suffix is not None:
repo_id += f"-{suffix}"

if api.repo_exists(repo_id):
if api.repo_exists(repo_id) and not force:
print(f"Model {repo_id} already exists, skipping")
else:
model.push_to_hub(repo_id)
Expand Down Expand Up @@ -297,19 +297,26 @@ def init_weights_tiny_model(model):
kwargs["projection_dim"] = 8
if config_class in [LlavaConfig, LlavaNextConfig, PaliGemmaConfig]:
vision_kwargs["projection_dim"] = 8
if config_class in [LlavaConfig, LlavaNextConfig]:
if config_class in [LlavaConfig, LlavaNextConfig, Gemma3Config]:
vision_kwargs["image_size"] = 336
vision_kwargs["patch_size"] = 14
vision_kwargs["patch_size"] = 20
processor.image_processor.size = {"height": 336, "width": 336}
if config_class in [Qwen2VLConfig, Qwen2_5_VLConfig]:
kwargs["vision_start_token_id"] = 151652
text_kwargs["rope_scaling"] = {"type": "mrope", "mrope_section": [1]}
kwargs["vision_end_token_id"] = 151653
kwargs["vision_token_id"] = 151654
kwargs["image_token_id"] = 151655
kwargs["vocab_size"] = len(processor.tokenizer.vocab)
text_kwargs["rope_scaling"] = {"type": "mrope", "mrope_section": [2]}
vision_kwargs["depth"] = 4
vision_kwargs["embed_dim"] = 64
if config_class in [Qwen2_5_VLConfig]:
vision_kwargs["out_hidden_size"] = 16

config = config_class(
text_config=dict(
vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder),
hidden_size=8,
hidden_size=16,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
Expand All @@ -326,4 +333,4 @@ def init_weights_tiny_model(model):
**kwargs,
)
model = model_class(config)
push_to_hub(model, processor, "tiny")
push_to_hub(model, processor, "tiny", force=True)
28 changes: 28 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1472,3 +1472,31 @@ def test_train_with_torch_dtype(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@parameterized.expand(
[
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
]
)
def test_train_vlm(self, model_id):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
6 changes: 3 additions & 3 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def reward_func(completions, **kwargs):
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
Expand Down Expand Up @@ -534,9 +534,9 @@ def __init__(
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
raise ValueError(
warnings.warn(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
"The `model_init_kwargs` will be ignored."
)

# Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it
Expand Down
Loading
Loading