Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,17 +1003,16 @@ def test_skip_prepare_dataset_passes_truncation_to_text_collator(self):
assert trainer.data_collator.max_length == 16
assert trainer.data_collator.truncation_mode == "keep_end"

def test_skip_prepare_dataset_with_padding_free_and_max_length_raises(self):
def test_padding_free_without_packing_and_max_length_raises(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]")
training_args = SFTConfig(
output_dir=self.tmp_dir,
max_length=16,
padding_free=True,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)

with pytest.raises(ValueError, match="must be enforced during dataset preparation or packing"):
with pytest.raises(ValueError, match="`max_length` is not enforced"):
SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
Expand Down
23 changes: 9 additions & 14 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
maybe_convert_to_chatml,
pack_dataset,
prepare_multimodal_messages,
truncate_dataset,
)
from ..models import get_act_offloading_ctx_manager
from .base_trainer import _BaseTrainer
Expand Down Expand Up @@ -572,6 +571,8 @@ class SFTTrainer(_BaseTrainer):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
Custom collators must truncate sequences before padding; the trainer does not apply post-collation
truncation.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
Expand Down Expand Up @@ -913,19 +914,19 @@ def __init__(
)

# Dataset
if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
"When `padding_free=True` without packing, `max_length` is not enforced. Either enable packing "
"(e.g., `packing=True, packing_strategy='bfd'`), provide already truncated inputs, or set "
"`max_length=None`."
)
# Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
# preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
skip_prepare_dataset = (
args.dataset_kwargs is not None
and args.dataset_kwargs.get("skip_prepare_dataset", False)
or self._is_vision_dataset
)
if skip_prepare_dataset and self.padding_free and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
"When `padding_free=True`, `max_length` must be enforced during dataset preparation or packing, not in "
"the collator. Disable `skip_prepare_dataset`, provide already packed/truncated inputs, or set "
"`max_length=None`."
)
if not skip_prepare_dataset:
if self.completion_only_loss and formatting_func:
raise ValueError(
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
**map_kwargs,
)

# Pack or truncate
# Pack
if packing:
if args.max_length is None:
raise ValueError("When packing is enabled, `max_length` can't be `None`.")
Expand All @@ -1213,12 +1214,6 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo

# Packing adds a new "seq_lengths" column, which is needed for document-aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
dataset = truncate_dataset(
dataset, args.max_length, truncation_mode=args.truncation_mode, map_kwargs=map_kwargs
)
# For the Liger kernel, keep only the essential columns (those the collator expects)
if args.use_liger_kernel:
collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}
Expand Down
Loading