From 4af1087b305c6a0be07798216040c2f6afdd91c4 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Mon, 27 May 2024 23:59:46 -0500
Subject: [PATCH 1/3] Add test for skipping preproc if packing=False

Signed-off-by: Alex-Brooks
---
 tests/test_sft_trainer.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index dd896823e27..5fad8c6c946 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1156,6 +1156,29 @@ def test_sft_trainer_skip_prepare_dataset(self):
             assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
             assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features
 
+    def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = SFTConfig(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                eval_strategy="steps",
+                max_steps=4,
+                eval_steps=2,
+                save_steps=2,
+                per_device_train_batch_size=2,
+                gradient_checkpointing=True,
+                remove_unused_columns=False,
+                packing=False,
+                dataset_kwargs={"skip_prepare_dataset": True},
+            )
+
+            trainer = SFTTrainer(
+                model=self.model_id,
+                args=training_args,
+                train_dataset=self.dummy_vsft_instruction_dataset,
+            )
+            assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
+
     @requires_pil
     def test_sft_trainer_llava(self):
         with tempfile.TemporaryDirectory() as tmp_dir:

From 23a0bb2837817b63dea575c553a00e814eced21d Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Tue, 28 May 2024 00:02:19 -0500
Subject: [PATCH 2/3] Allow skipping of validation for packing=False

Signed-off-by: Alex-Brooks
---
 trl/trainer/sft_trainer.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index c3bb7c7c98b..a11c99128e5 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -319,7 +319,15 @@ def make_inputs_require_grad(module, input, output):
                     dataset_kwargs["add_special_tokens"] = False
 
         if not args.packing:
-            if args.dataset_text_field is None and formatting_func is None:
+            # If we aren't skipping data preparation, then a dataset_text_field
+            # or formatting_func must be provided.
+            if (
+                args.dataset_text_field is None
+                and formatting_func is None
+                and not (
+                    dataset_kwargs is not None and dataset_kwargs.get("skip_prepare_dataset", False)
+                )
+            ):
                 raise ValueError(
                     "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
                 )

From c38a586ae6d68f3101360c134c3695d3a65e9054 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 31 May 2024 11:13:41 -0600
Subject: [PATCH 3/3] Use dummy dataset in no packing preproc test

Signed-off-by: Alex-Brooks
---
 tests/test_sft_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 5fad8c6c946..30729a6a413 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1175,9 +1175,9 @@ def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
             trainer = SFTTrainer(
                 model=self.model_id,
                 args=training_args,
-                train_dataset=self.dummy_vsft_instruction_dataset,
+                train_dataset=self.dummy_dataset,
             )
-            assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
+            assert trainer.train_dataset.features == self.dummy_dataset.features
 
     @requires_pil
     def test_sft_trainer_llava(self):