diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index f1e96d09e53..d423ec7943d 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1003,17 +1003,16 @@ def test_skip_prepare_dataset_passes_truncation_to_text_collator(self): assert trainer.data_collator.max_length == 16 assert trainer.data_collator.truncation_mode == "keep_end" - def test_skip_prepare_dataset_with_padding_free_and_max_length_raises(self): + def test_padding_free_without_packing_and_max_length_raises(self): dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]") training_args = SFTConfig( output_dir=self.tmp_dir, max_length=16, padding_free=True, - dataset_kwargs={"skip_prepare_dataset": True}, report_to="none", ) - with pytest.raises(ValueError, match="must be enforced during dataset preparation or packing"): + with pytest.raises(ValueError, match="`max_length` is not enforced"): SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 08dbe524133..77ab546774a 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -51,7 +51,6 @@ maybe_convert_to_chatml, pack_dataset, prepare_multimodal_messages, - truncate_dataset, ) from ..models import get_act_offloading_ctx_manager from .base_trainer import _BaseTrainer @@ -572,6 +571,8 @@ class SFTTrainer(_BaseTrainer): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. + Custom collators must truncate sequences before padding; the trainer does not apply post-collation + truncation. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: @@ -913,6 +914,12 @@ def __init__( ) # Dataset + if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset: + raise ValueError( + "When `padding_free=True` without packing, `max_length` is not enforced. Either enable packing " + "(e.g., `packing=True, packing_strategy='bfd'`), provide already truncated inputs, or set " + "`max_length=None`." + ) # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead. skip_prepare_dataset = ( @@ -920,12 +927,6 @@ def __init__( and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vision_dataset ) - if skip_prepare_dataset and self.padding_free and args.max_length is not None and not self._is_vision_dataset: - raise ValueError( - "When `padding_free=True`, `max_length` must be enforced during dataset preparation or packing, not in " - "the collator. Disable `skip_prepare_dataset`, provide already packed/truncated inputs, or set " - "`max_length=None`." - ) if not skip_prepare_dataset: if self.completion_only_loss and formatting_func: raise ValueError( @@ -1191,7 +1192,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo **map_kwargs, ) - # Pack or truncate + # Pack if packing: if args.max_length is None: raise ValueError("When packing is enabled, `max_length` can't be `None`.") @@ -1213,12 +1214,6 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo # Packing adds new column "seq_lengths" needed for document aware FlashAttention dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) - elif args.max_length is not None: - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Truncating {dataset_name} dataset" - dataset = truncate_dataset( - dataset, args.max_length, truncation_mode=args.truncation_mode, map_kwargs=map_kwargs - ) # For Liger kernel, ensure only the essential columns if args.use_liger_kernel: collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}