Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
maybe_convert_to_chatml,
pack_dataset,
prepare_multimodal_messages,
- truncate_dataset,
)
from ..models import get_act_offloading_ctx_manager
from .base_trainer import _BaseTrainer
Expand Down Expand Up @@ -572,6 +571,8 @@ class SFTTrainer(_BaseTrainer):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
+ Custom collators must truncate sequences before padding; the trainer does not apply post-collation
+ truncation.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
Expand Down Expand Up @@ -923,8 +924,7 @@ def __init__(
if skip_prepare_dataset and self.padding_free and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
"When `padding_free=True`, `max_length` must be enforced during dataset preparation or packing, not in "
- "the collator. Disable `skip_prepare_dataset`, provide already packed/truncated inputs, or set "
- "`max_length=None`."
+ "the collator. Provide already packed/truncated inputs, or set `max_length=None`."
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
)
if not skip_prepare_dataset:
if self.completion_only_loss and formatting_func:
Expand Down Expand Up @@ -1188,7 +1188,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
**map_kwargs,
)

- # Pack or truncate
+ # Pack
if packing:
if args.max_length is None:
raise ValueError("When packing is enabled, `max_length` can't be `None`.")
Expand All @@ -1210,12 +1210,6 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo

# Packing adds new column "seq_lengths" needed for document aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
- elif args.max_length is not None:
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
- dataset = truncate_dataset(
- dataset, args.max_length, truncation_mode=args.truncation_mode, map_kwargs=map_kwargs
- )
# For Liger kernel, ensure only the essential columns
if args.use_liger_kernel:
collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}
Expand Down
Loading