From bcd0efb55de10bcba9a835d994e2f09645a1eb6c Mon Sep 17 00:00:00 2001 From: Bram Vanroy <2779410+BramVanroy@users.noreply.github.com> Date: Thu, 11 Apr 2024 14:30:11 +0200 Subject: [PATCH] allow pre-tokenized datasets --- trl/trainer/sft_trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 2194353f6cd..af2e1b88f64 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -401,7 +401,16 @@ def _prepare_dataset( if dataset is None: raise ValueError("The dataset should not be None") + # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is + # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset + column_names = ( + dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None + ) + if column_names and "input_ids" in column_names: + return dataset + # check if torch dataset / dataloader and do nothing + # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check if isinstance( dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset) ) and not isinstance(dataset, datasets.IterableDataset):