From 0265d383367bfabfe931bc71bbe5068103c86326 Mon Sep 17 00:00:00 2001 From: kschwethelm Date: Thu, 20 Nov 2025 16:33:32 +0100 Subject: [PATCH 1/3] fix(SFTTrainer): Support VLM processors in completion-only text training - Access EOS token from processing_class.tokenizer for VLMs - Handle nested list outputs from VLM processors in prompt-completion mode --- trl/trainer/sft_trainer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index f446efdc63c..d1bdcebb810 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -935,9 +935,10 @@ def add_eos(example, eos_token): example["completion"] = example["completion"] + eos_token return example + eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token dataset = dataset.map( add_eos, - fn_kwargs={"eos_token": processing_class.eos_token}, + fn_kwargs={"eos_token": eos_token}, remove_columns="messages" if "messages" in column_names else None, # renamed to "text" **map_kwargs, ) @@ -988,6 +989,14 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"])[ "input_ids" ] + # Fix transformers inconsistency: for VLMs, processing_class returns lists of lists + # even for single examples, while for LLMs it returns lists of ints. 
+ prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids + prompt_completion_ids = ( + prompt_completion_ids[0] + if isinstance(prompt_completion_ids[0], list) + else prompt_completion_ids + ) # Check if the tokenized prompt starts with the tokenized prompt+completion if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: From 3cd4938b0da9f696d9aef11c2bd368973743c047 Mon Sep 17 00:00:00 2001 From: kschwethelm Date: Thu, 20 Nov 2025 16:59:43 +0100 Subject: [PATCH 2/3] Added pytests for vlm text only data prompt completion --- tests/test_sft_trainer.py | 72 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 874d5304f2f..1cfda535186 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1562,6 +1562,78 @@ def test_train_vlm_text_only_data(self, model_id): else: assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + ], + ) + @require_vision + def test_train_vlm_text_only_data_prompt_completion(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if 
n.startswith("model.visual"): + assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated" + else: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + ], + ) + @require_vision + def test_train_vlm_text_only_data_prompt_completion_non_conversational(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n.startswith("model.visual"): + assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated" + else: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + @require_peft def test_prompt_tuning(self): """Test that SFT works with Prompt Tuning.""" From 93db2345db7f171ceda1ad43756f1edfcaeecce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 21:32:14 +0000 Subject: [PATCH 3/3] parametrize --- tests/test_sft_trainer.py | 76 +++------------------------------------ 1 file changed, 4 insertions(+), 72 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 1cfda535186..7f7cfecbbc7 100644 --- a/tests/test_sft_trainer.py +++ 
b/tests/test_sft_trainer.py @@ -1532,82 +1532,14 @@ def test_train_vlm_gemma_3n(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", ], ) - @require_vision - def test_train_vlm_text_only_data(self, model_id): - # Get the dataset - dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") - - # Initialize the trainer - training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") - trainer = SFTTrainer( - model=model_id, - args=training_args, - train_dataset=dataset, - ) - - # Save the initial parameters to compare them later - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - - # Train the model - trainer.train() - - # Check that the training loss is not None - assert trainer.state.log_history[-1]["train_loss"] is not None - - # Check the params have changed - for n, param in previous_trainable_params.items(): - new_param = trainer.model.get_parameter(n) - if n.startswith("model.visual"): - assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated" - else: - assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" - - @pytest.mark.parametrize( - "model_id", - [ - "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - ], - ) - @require_vision - def test_train_vlm_text_only_data_prompt_completion(self, model_id): - # Get the dataset - dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") - - # Initialize the trainer - training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") - trainer = SFTTrainer( - model=model_id, - args=training_args, - train_dataset=dataset, - ) - - # Save the initial parameters to compare them later - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - - # Train the model - trainer.train() - - # Check that the training loss is not None - assert 
trainer.state.log_history[-1]["train_loss"] is not None - - # Check the params have changed - for n, param in previous_trainable_params.items(): - new_param = trainer.model.get_parameter(n) - if n.startswith("model.visual"): - assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated" - else: - assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" - @pytest.mark.parametrize( - "model_id", - [ - "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - ], + "dataset_config", + ["conversational_language_modeling", "conversational_prompt_completion", "standard_prompt_completion"], ) @require_vision - def test_train_vlm_text_only_data_prompt_completion_non_conversational(self, model_id): + def test_train_vlm_text_only_data(self, model_id, dataset_config): # Get the dataset - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train") + dataset = load_dataset("trl-internal-testing/zen", dataset_config, split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")