Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions trl/trainer/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class SFTConfig(TrainingArguments):
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
shuffle_dataset (`bool`, *optional*, defaults to `False`):
Whether to shuffle the dataset.
packing (`bool`, *optional*, defaults to `False`):
Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
padding. Uses `max_length` to define sequence length.
Expand Down Expand Up @@ -197,6 +199,10 @@ class SFTConfig(TrainingArguments):
"sequence length."
},
)
shuffle_dataset: bool = field(
default=False,
metadata={"help": "Whether to shuffle the dataset."},
)
packing: bool = field(
default=False,
metadata={
Expand Down
8 changes: 8 additions & 0 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,11 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo

dataset = dataset.select_columns(columns)

# Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before
# packing as well to avoid correlations between sequences packed together.
if args.shuffle_dataset:
dataset = dataset.shuffle(seed=args.seed)

# Packing adds new column "seq_lengths" needed for document aware FlashAttention
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
elif args.max_length is not None:
Expand All @@ -1083,6 +1088,9 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
column_names = get_dataset_column_names(dataset)
dataset = dataset.select_columns(collator_expected_keys.intersection(column_names))

if args.shuffle_dataset:
dataset = dataset.shuffle(seed=args.seed)

return dataset

def _set_signature_columns_if_needed(self):
Expand Down
Loading