diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index 6345d2a4bd7..d49c49ae48a 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -64,6 +64,8 @@ class SFTConfig(TrainingArguments): max_length (`int` or `None`, *optional*, defaults to `1024`): Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + shuffle_dataset (`bool`, *optional*, defaults to `False`): + Whether to shuffle the dataset. packing (`bool`, *optional*, defaults to `False`): Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce padding. Uses `max_length` to define sequence length. @@ -197,6 +199,10 @@ class SFTConfig(TrainingArguments): "sequence length." }, ) + shuffle_dataset: bool = field( + default=False, + metadata={"help": "Whether to shuffle the dataset."}, + ) packing: bool = field( default=False, metadata={ diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index c33781d608a..8e0ae98a480 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -1071,6 +1071,11 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo dataset = dataset.select_columns(columns) + # Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before + # packing as well to avoid correlations between sequences packed together. + if args.shuffle_dataset: + dataset = dataset.shuffle(seed=args.seed) + # Packing adds new column "seq_lengths" needed for document aware FlashAttention dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) elif args.max_length is not None: @@ -1083,6 +1088,9 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo column_names = get_dataset_column_names(dataset) dataset = dataset.select_columns(collator_expected_keys.intersection(column_names)) + if args.shuffle_dataset: + dataset = dataset.shuffle(seed=args.seed) + return dataset def _set_signature_columns_if_needed(self):