diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 2194353f6cd..af2e1b88f64 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -401,7 +401,16 @@ def _prepare_dataset( if dataset is None: raise ValueError("The dataset should not be None") + # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is + # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset + column_names = ( + dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None + ) + if column_names and "input_ids" in column_names: + return dataset + # check if torch dataset / dataloader and do nothing + # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check if isinstance( dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset) ) and not isinstance(dataset, datasets.IterableDataset):