Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,17 +1003,16 @@ def test_skip_prepare_dataset_passes_truncation_to_text_collator(self):
assert trainer.data_collator.max_length == 16
assert trainer.data_collator.truncation_mode == "keep_end"

def test_skip_prepare_dataset_with_padding_free_and_max_length_raises(self):
def test_padding_free_without_packing_and_max_length_raises(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]")
training_args = SFTConfig(
output_dir=self.tmp_dir,
max_length=16,
padding_free=True,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)

with pytest.raises(ValueError, match="must be enforced during dataset preparation or packing"):
with pytest.raises(ValueError, match="max_length` is not enforced"):
Comment thread
albertvillanova marked this conversation as resolved.
Outdated
SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
Expand Down
23 changes: 9 additions & 14 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
maybe_convert_to_chatml,
pack_dataset,
prepare_multimodal_messages,
truncate_dataset,
)
from ..models import get_act_offloading_ctx_manager
from .base_trainer import _BaseTrainer
Expand Down Expand Up @@ -572,6 +571,8 @@ class SFTTrainer(_BaseTrainer):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
Custom collators must truncate sequences before padding; the trainer does not apply post-collation
truncation.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
Expand Down Expand Up @@ -913,19 +914,19 @@ def __init__(
)

# Dataset
if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
"When `padding_free=True` without packing, `max_length` is not enforced. Either enable packing "
"(e.g., `packing=True, packing_strategy='bfd'`), provide already truncated inputs, or set "
"`max_length=None`."
)
# Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
# preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
skip_prepare_dataset = (
args.dataset_kwargs is not None
and args.dataset_kwargs.get("skip_prepare_dataset", False)
or self._is_vision_dataset
)
if skip_prepare_dataset and self.padding_free and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
"When `padding_free=True`, `max_length` must be enforced during dataset preparation or packing, not in "
"the collator. Disable `skip_prepare_dataset`, provide already packed/truncated inputs, or set "
"`max_length=None`."
)
if not skip_prepare_dataset:
if self.completion_only_loss and formatting_func:
raise ValueError(
Expand Down Expand Up @@ -1188,7 +1189,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
**map_kwargs,
)

# Pack or truncate
# Pack
if packing:
if args.max_length is None:
raise ValueError("When packing is enabled, `max_length` can't be `None`.")
Expand All @@ -1210,12 +1211,6 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo

# Packing adds new column "seq_lengths" needed for document aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
dataset = truncate_dataset(
dataset, args.max_length, truncation_mode=args.truncation_mode, map_kwargs=map_kwargs
)
# For Liger kernel, ensure only the essential columns
if args.use_liger_kernel:
collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}
Expand Down
Loading