Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,29 @@ def test_sft_trainer_skip_prepare_dataset(self):
assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features

def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
eval_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
remove_unused_columns=False,
packing=False,
dataset_kwargs={"skip_prepare_dataset": True},
)

trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_dataset,
)
assert trainer.train_dataset.features == self.dummy_dataset.features

@requires_pil
def test_sft_trainer_llava(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
10 changes: 9 additions & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,15 @@ def make_inputs_require_grad(module, input, output):
dataset_kwargs["add_special_tokens"] = False

if not args.packing:
if args.dataset_text_field is None and formatting_func is None:
# If we aren't skipping data preparation, then a dataset_text_field
# or formatting_func must be provided.
if (
args.dataset_text_field is None
and formatting_func is None
and dataset_kwargs is not None
and "skip_prepare_dataset" in dataset_kwargs
and dataset_kwargs["skip_prepare_dataset"]
):
raise ValueError(
"You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
)
Expand Down